]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/flow_divert.c
xnu-3248.20.55.tar.gz
[apple/xnu.git] / bsd / netinet / flow_divert.c
index 697016c4996a806e54d433e36391091f817c69d4..76e29f8d607a50027fec5e7b0856e9db2c77a13b 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012-2014 Apple Inc. All rights reserved.
+ * Copyright (c) 2012-2015 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  * 
@@ -57,6 +57,7 @@
 #include <netinet/flow_divert.h>
 #include <netinet/flow_divert_proto.h>
 #if INET6
+#include <netinet6/in6_pcb.h>
 #include <netinet6/ip6protosw.h>
 #endif /* INET6 */
 #include <dev/random/randomdev.h>
@@ -149,13 +150,40 @@ static    kern_ctl_ref                            g_flow_divert_kctl_ref                  = NULL;
 
 static struct protosw                          g_flow_divert_in_protosw;
 static struct pr_usrreqs                       g_flow_divert_in_usrreqs;
+static struct protosw                          g_flow_divert_in_udp_protosw;
+static struct pr_usrreqs                       g_flow_divert_in_udp_usrreqs;
 #if INET6
 static struct ip6protosw                       g_flow_divert_in6_protosw;
 static struct pr_usrreqs                       g_flow_divert_in6_usrreqs;
+static struct ip6protosw                       g_flow_divert_in6_udp_protosw;
+static struct pr_usrreqs                       g_flow_divert_in6_udp_usrreqs;
 #endif /* INET6 */
 
 static struct protosw                          *g_tcp_protosw                                  = NULL;
 static struct ip6protosw                       *g_tcp6_protosw                                 = NULL;
+static struct protosw                          *g_udp_protosw                                  = NULL;
+static struct ip6protosw                       *g_udp6_protosw                                 = NULL;
+
+static errno_t
+flow_divert_dup_addr(sa_family_t family, struct sockaddr *addr, struct sockaddr **dup);
+
+static errno_t
+flow_divert_inp_to_sockaddr(const struct inpcb *inp, struct sockaddr **local_socket);
+
+static boolean_t
+flow_divert_is_sockaddr_valid(struct sockaddr *addr);
+
+static int
+flow_divert_append_target_endpoint_tlv(mbuf_t connect_packet, struct sockaddr *toaddr);
+
+struct sockaddr *
+flow_divert_get_buffered_target_address(mbuf_t buffer);
+
+static boolean_t
+flow_divert_has_pcb_local_address(const struct inpcb *inp);
+
+static void
+flow_divert_disconnect_socket(struct socket *so);
 
 static inline int
 flow_divert_pcb_cmp(const struct flow_divert_pcb *pcb_a, const struct flow_divert_pcb *pcb_b)
@@ -210,12 +238,11 @@ flow_divert_pcb_lookup(uint32_t hash, struct flow_divert_group *group)
 static errno_t
 flow_divert_pcb_insert(struct flow_divert_pcb *fd_cb, uint32_t ctl_unit)
 {
-       int                                                     error                                           = 0;
+       errno_t                                                 error                                           = 0;
        struct                                          flow_divert_pcb *exist          = NULL;
        struct flow_divert_group        *group;
        static uint32_t                         g_nextkey                                       = 1;
        static uint32_t                         g_hash_seed                                     = 0;
-       errno_t                                         result                                          = 0;
        int                                                     try_count                                       = 0;
 
        if (ctl_unit == 0 || ctl_unit >= GROUP_COUNT_MAX) {
@@ -277,7 +304,7 @@ flow_divert_pcb_insert(struct flow_divert_pcb *fd_cb, uint32_t ctl_unit)
                FDRETAIN(fd_cb);                /* The group now has a reference */
        } else {
                fd_cb->hash = 0;
-               result = EEXIST;
+               error = EEXIST;
        }
 
        socket_unlock(fd_cb->so, 0);
@@ -286,7 +313,7 @@ done:
        lck_rw_done(&g_flow_divert_group_lck);
        socket_lock(fd_cb->so, 0);
 
-       return result;
+       return error;
 }
 
 static struct flow_divert_pcb *
@@ -371,10 +398,10 @@ flow_divert_packet_init(struct flow_divert_pcb *fd_cb, uint8_t packet_type, mbuf
 }
 
 static int
-flow_divert_packet_append_tlv(mbuf_t packet, uint8_t type, size_t length, const void *value)
+flow_divert_packet_append_tlv(mbuf_t packet, uint8_t type, uint32_t length, const void *value)
 {
-       size_t  net_length      = htonl(length);
-       int             error           = 0;
+       uint32_t        net_length      = htonl(length);
+       int                     error           = 0;
 
        error = mbuf_copyback(packet, mbuf_pkthdr_len(packet), sizeof(type), &type, MBUF_DONTWAIT);
        if (error) {
@@ -400,10 +427,10 @@ flow_divert_packet_append_tlv(mbuf_t packet, uint8_t type, size_t length, const
 static int
 flow_divert_packet_find_tlv(mbuf_t packet, int offset, uint8_t type, int *err, int next)
 {
-       size_t  cursor                  = offset;
-       int             error                   = 0;
-       size_t  curr_length;
-       uint8_t curr_type;
+       size_t          cursor                  = offset;
+       int                     error                   = 0;
+       uint32_t        curr_length;
+       uint8_t         curr_type;
 
        *err = 0;
 
@@ -435,11 +462,11 @@ flow_divert_packet_find_tlv(mbuf_t packet, int offset, uint8_t type, int *err, i
 }
 
 static int
-flow_divert_packet_get_tlv(mbuf_t packet, int offset, uint8_t type, size_t buff_len, void *buff, size_t *val_size)
+flow_divert_packet_get_tlv(mbuf_t packet, int offset, uint8_t type, size_t buff_len, void *buff, uint32_t *val_size)
 {
-       int             error           = 0;
-       size_t  length;
-       int             tlv_offset;
+       int                     error           = 0;
+       uint32_t        length;
+       int                     tlv_offset;
 
        tlv_offset = flow_divert_packet_find_tlv(packet, offset, type, &error, 0);
        if (tlv_offset < 0) {
@@ -778,8 +805,9 @@ flow_divert_trie_insert(struct flow_divert_trie *trie, uint16_t string_start, si
        return current;
 }
 
+#define APPLE_WEBCLIP_ID_PREFIX        "com.apple.webapp"
 static uint16_t
-flow_divert_trie_search(struct flow_divert_trie *trie, const uint8_t *string_bytes)
+flow_divert_trie_search(struct flow_divert_trie *trie, uint8_t *string_bytes)
 {
        uint16_t current = trie->root;
        uint16_t string_idx = 0;
@@ -796,6 +824,10 @@ flow_divert_trie_search(struct flow_divert_trie *trie, const uint8_t *string_byt
                if (node_idx == node_end) {
                        if (string_bytes[string_idx] == '\0') {
                                return current; /* Got an exact match */
+                       } else if (string_idx == strlen(APPLE_WEBCLIP_ID_PREFIX) &&
+                                  0 == strncmp((const char *)string_bytes, APPLE_WEBCLIP_ID_PREFIX, string_idx)) {
+                               string_bytes[string_idx] = '\0'; 
+                               return current; /* Got an apple webclip id prefix match */
                        } else if (TRIE_NODE(trie, current).child_map != NULL_TRIE_IDX) {
                                next = TRIE_CHILD(trie, current, string_bytes[string_idx]);
                        }
@@ -841,7 +873,7 @@ flow_divert_send_packet(struct flow_divert_pcb *fd_cb, mbuf_t packet, Boolean en
 
        if (fd_cb->group == NULL) {
                fd_cb->so->so_error = ECONNABORTED;
-               soisdisconnected(fd_cb->so);
+               flow_divert_disconnect_socket(fd_cb->so);
                return ECONNABORTED;
        }
 
@@ -873,6 +905,7 @@ static int
 flow_divert_send_connect(struct flow_divert_pcb *fd_cb, struct sockaddr *to, mbuf_t connect_packet)
 {
        int                             error                   = 0;
+       int                             flow_type               = 0;
 
        error = flow_divert_packet_append_tlv(connect_packet,
                                              FLOW_DIVERT_TLV_TRAFFIC_CLASS,
@@ -882,6 +915,23 @@ flow_divert_send_connect(struct flow_divert_pcb *fd_cb, struct sockaddr *to, mbu
                goto done;
        }
 
+       if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
+               flow_type = FLOW_DIVERT_FLOW_TYPE_TCP;
+       } else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
+               flow_type = FLOW_DIVERT_FLOW_TYPE_UDP;
+       } else {
+               error = EINVAL;
+               goto done;
+       }
+       error = flow_divert_packet_append_tlv(connect_packet,
+                                             FLOW_DIVERT_TLV_FLOW_TYPE,
+                                             sizeof(flow_type),
+                                             &flow_type);
+
+       if (error) {
+               goto done;
+       }
+
        if (fd_cb->so->so_flags & SOF_DELEGATED) {
                error = flow_divert_packet_append_tlv(connect_packet,
                                                      FLOW_DIVERT_TLV_PID,
@@ -923,33 +973,27 @@ flow_divert_send_connect(struct flow_divert_pcb *fd_cb, struct sockaddr *to, mbu
                fd_cb->connect_token = NULL;
        } else {
                uint32_t ctl_unit = htonl(fd_cb->control_group_unit);
-               int port;
 
                error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_CTL_UNIT, sizeof(ctl_unit), &ctl_unit);
                if (error) {
                        goto done;
                }
 
-               error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_ADDRESS, to->sa_len, to);
-               if (error) {
-                       goto done;
-               }
-
-               if (to->sa_family == AF_INET) {
-                       port = ntohs((satosin(to))->sin_port);
-               }
-#if INET6
-               else {
-                       port = ntohs((satosin6(to))->sin6_port);
-               }
-#endif
-
-               error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_PORT, sizeof(port), &port);
+               error = flow_divert_append_target_endpoint_tlv(connect_packet, to);
                if (error) {
                        goto done;
                }
        }
 
+       if (fd_cb->local_address != NULL) {
+               /* socket is bound. */
+                error = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_LOCAL_ADDR,
+                                                       sizeof(struct sockaddr_storage), fd_cb->local_address);
+                if (error) {
+                        goto done;
+                }
+        }
+
        error = flow_divert_send_packet(fd_cb, connect_packet, TRUE);
        if (error) {
                goto done;
@@ -972,7 +1016,7 @@ flow_divert_send_connect_result(struct flow_divert_pcb *fd_cb)
                goto done;
        }
 
-       rbuff_space = sbspace(&fd_cb->so->so_rcv);
+       rbuff_space = fd_cb->so->so_rcv.sb_hiwat;
        if (rbuff_space < 0) {
                rbuff_space = 0;
        }
@@ -992,7 +1036,7 @@ flow_divert_send_connect_result(struct flow_divert_pcb *fd_cb)
 
 done:
        if (error && packet != NULL) {
-               mbuf_free(packet);
+               mbuf_freem(packet);
        }
 
        return error;
@@ -1092,12 +1136,12 @@ flow_divert_send_close_if_needed(struct flow_divert_pcb *fd_cb)
        }
 
        if (flow_divert_tunnel_how_closed(fd_cb) == SHUT_RDWR) {
-               soisdisconnected(fd_cb->so);
+               flow_divert_disconnect_socket(fd_cb->so);
        }
 }
 
 static errno_t
-flow_divert_send_data_packet(struct flow_divert_pcb *fd_cb, mbuf_t data, size_t data_len, Boolean force)
+flow_divert_send_data_packet(struct flow_divert_pcb *fd_cb, mbuf_t data, size_t data_len, struct sockaddr *toaddr, Boolean force)
 {
        mbuf_t  packet;
        mbuf_t  last;
@@ -1109,15 +1153,22 @@ flow_divert_send_data_packet(struct flow_divert_pcb *fd_cb, mbuf_t data, size_t
                return error;
        }
 
+       if (toaddr != NULL) {
+               error = flow_divert_append_target_endpoint_tlv(packet, toaddr);
+               if (error) {
+                       FDLOG(LOG_ERR, fd_cb, "flow_divert_append_target_endpoint_tlv() failed: %d", error);
+                       return error;
+               }
+       }
+
        last = m_last(packet);
        mbuf_setnext(last, data);
        mbuf_pkthdr_adjustlen(packet, data_len);
-
        error = flow_divert_send_packet(fd_cb, packet, force);
 
        if (error) {
                mbuf_setnext(last, NULL);
-               mbuf_free(packet);
+               mbuf_freem(packet);
        } else {
                fd_cb->bytes_sent += data_len;
                flow_divert_add_data_statistics(fd_cb, data_len, TRUE);
@@ -1147,28 +1198,72 @@ flow_divert_send_buffered_data(struct flow_divert_pcb *fd_cb, Boolean force)
                to_send = fd_cb->send_window;
        }
 
-       while (sent < to_send) {
-               mbuf_t  data;
-               size_t  data_len;
+       if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
+               while (sent < to_send) {
+                       mbuf_t  data;
+                       size_t  data_len;
 
-               data_len = to_send - sent;
-               if (data_len > FLOW_DIVERT_CHUNK_SIZE) {
-                       data_len = FLOW_DIVERT_CHUNK_SIZE;
-               }
+                       data_len = to_send - sent;
+                       if (data_len > FLOW_DIVERT_CHUNK_SIZE) {
+                               data_len = FLOW_DIVERT_CHUNK_SIZE;
+                       }
 
-               error = mbuf_copym(buffer, sent, data_len, MBUF_DONTWAIT, &data);
-               if (error) {
-                       FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
-                       break;
-               }
+                       error = mbuf_copym(buffer, sent, data_len, MBUF_DONTWAIT, &data);
+                       if (error) {
+                               FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
+                               break;
+                       }
 
-               error = flow_divert_send_data_packet(fd_cb, data, data_len, force);
-               if (error) {
-                       mbuf_free(data);
-                       break;
-               }
+                       error = flow_divert_send_data_packet(fd_cb, data, data_len, NULL, force);
+                       if (error) {
+                               mbuf_freem(data);
+                               break;
+                       }
 
-               sent += data_len;
+                       sent += data_len;
+               }
+               sbdrop(&fd_cb->so->so_snd, sent);
+               sowwakeup(fd_cb->so);
+       } else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
+               mbuf_t data;
+               mbuf_t m;
+               size_t data_len;
+
+               while(buffer) {
+                       struct sockaddr *toaddr = flow_divert_get_buffered_target_address(buffer);
+
+                       m = buffer;
+                       if (toaddr != NULL) {
+                               /* look for data in the chain */
+                               do {
+                                       m = m->m_next;
+                                       if (m != NULL && m->m_type == MT_DATA) {
+                                               break;
+                                       }
+                               } while(m);
+                               if (m == NULL) {
+                                       /* unexpected */
+                                       FDLOG0(LOG_ERR, fd_cb, "failed to find type MT_DATA in the mbuf chain.");
+                                       goto move_on;
+                               }
+                       }
+                       data_len = mbuf_pkthdr_len(m);
+                       FDLOG(LOG_DEBUG, fd_cb, "mbuf_copym() data_len = %u", data_len);
+                       error = mbuf_copym(m, 0, data_len, MBUF_DONTWAIT, &data);
+                       if (error) {
+                               FDLOG(LOG_ERR, fd_cb, "mbuf_copym failed: %d", error);
+                               break;
+                       }
+                       error = flow_divert_send_data_packet(fd_cb, data, data_len, toaddr, force);
+                       if (error) {
+                               mbuf_freem(data);
+                               break;
+                       }
+                       sent += data_len;
+move_on:
+                       buffer = buffer->m_nextpkt;
+                       (void) sbdroprecord(&(fd_cb->so->so_snd));
+               }
        }
 
        if (sent > 0) {
@@ -1178,19 +1273,14 @@ flow_divert_send_buffered_data(struct flow_divert_pcb *fd_cb, Boolean force)
                } else {
                        fd_cb->send_window = 0;
                }
-               sbdrop(&fd_cb->so->so_snd, sent);
-               sowwakeup(fd_cb->so);
        }
 }
 
 static int
-flow_divert_send_app_data(struct flow_divert_pcb *fd_cb, mbuf_t data)
+flow_divert_send_app_data(struct flow_divert_pcb *fd_cb, mbuf_t data, struct sockaddr *toaddr)
 {
        size_t  to_send         = mbuf_pkthdr_len(data);
-       size_t  sent            = 0;
-       int             error           = 0;
-       mbuf_t  remaining_data  = data;
-       mbuf_t  pkt_data        = NULL;
+       int     error           = 0;
 
        if (to_send > fd_cb->send_window) {
                to_send = fd_cb->send_window;
@@ -1200,57 +1290,94 @@ flow_divert_send_app_data(struct flow_divert_pcb *fd_cb, mbuf_t data)
                to_send = 0;    /* If the send buffer is non-empty, then we can't send anything */
        }
 
-       while (sent < to_send) {
-               size_t  pkt_data_len;
+       if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
+               size_t  sent            = 0;
+               mbuf_t  remaining_data  = data;
+               mbuf_t  pkt_data        = NULL;
+               while (sent < to_send) {
+                       size_t  pkt_data_len;
+
+                       pkt_data = remaining_data;
 
-               pkt_data = remaining_data;
+                       if ((to_send - sent) > FLOW_DIVERT_CHUNK_SIZE) {
+                               pkt_data_len = FLOW_DIVERT_CHUNK_SIZE;
+                       } else {
+                               pkt_data_len = to_send - sent;
+                       }
+
+                       if (pkt_data_len < mbuf_pkthdr_len(pkt_data)) {
+                               error = mbuf_split(pkt_data, pkt_data_len, MBUF_DONTWAIT, &remaining_data);
+                               if (error) {
+                                       FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
+                                       pkt_data = NULL;
+                                       break;
+                               }
+                       } else {
+                               remaining_data = NULL;
+                       }
+
+                       error = flow_divert_send_data_packet(fd_cb, pkt_data, pkt_data_len, NULL, FALSE);
 
-               if ((to_send - sent) > FLOW_DIVERT_CHUNK_SIZE) {
-                       pkt_data_len = FLOW_DIVERT_CHUNK_SIZE;
-                       error = mbuf_split(pkt_data, pkt_data_len, MBUF_DONTWAIT, &remaining_data);
                        if (error) {
-                               FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
-                               pkt_data = NULL;
                                break;
                        }
-               } else {
-                       pkt_data_len = to_send - sent;
-                       remaining_data = NULL;
-               }
 
-               error = flow_divert_send_data_packet(fd_cb, pkt_data, pkt_data_len, FALSE);
-
-               if (error) {
-                       break;
+                       pkt_data = NULL;
+                       sent += pkt_data_len;
                }
 
-               pkt_data = NULL;
-               sent += pkt_data_len;
-       }
+               fd_cb->send_window -= sent;
 
-       fd_cb->send_window -= sent;
+               error = 0;
 
-       error = 0;
-
-       if (pkt_data != NULL) {
-               if (sbspace(&fd_cb->so->so_snd) > 0) {
-                       if (!sbappendstream(&fd_cb->so->so_snd, pkt_data)) {
-                               FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with pkt_data, send buffer size = %u, send_window = %u\n",
-                                               fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
+               if (pkt_data != NULL) {
+                       if (sbspace(&fd_cb->so->so_snd) > 0) {
+                               if (!sbappendstream(&fd_cb->so->so_snd, pkt_data)) {
+                                       FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with pkt_data, send buffer size = %u, send_window = %u\n",
+                                                       fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
+                               }
+                       } else {
+                               error = ENOBUFS;
                        }
-               } else {
-                       error = ENOBUFS;
                }
-       }
 
-       if (remaining_data != NULL) {
-               if (sbspace(&fd_cb->so->so_snd) > 0) {
-                       if (!sbappendstream(&fd_cb->so->so_snd, remaining_data)) {
-                               FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with remaining_data, send buffer size = %u, send_window = %u\n",
-                                               fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
+               if (remaining_data != NULL) {
+                       if (sbspace(&fd_cb->so->so_snd) > 0) {
+                               if (!sbappendstream(&fd_cb->so->so_snd, remaining_data)) {
+                                       FDLOG(LOG_ERR, fd_cb, "sbappendstream failed with remaining_data, send buffer size = %u, send_window = %u\n",
+                                                       fd_cb->so->so_snd.sb_cc, fd_cb->send_window);
+                               }
+                       } else {
+                               error = ENOBUFS;
+                       }
+               }
+       } else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
+               if (to_send) {
+                       error = flow_divert_send_data_packet(fd_cb, data, to_send, toaddr, FALSE);
+                       if (error) {
+                               FDLOG(LOG_ERR, fd_cb, "flow_divert_send_data_packet failed. send data size = %u", to_send);
+                       } else {
+                               fd_cb->send_window -= to_send;
                        }
                } else {
-                       error = ENOBUFS;
+                       /* buffer it */
+                       if (sbspace(&fd_cb->so->so_snd) >= (int)mbuf_pkthdr_len(data)) {
+                               if (toaddr != NULL) {
+                                       if (!sbappendaddr(&fd_cb->so->so_snd, toaddr, data, NULL, &error)) {
+                                               FDLOG(LOG_ERR, fd_cb,
+                                                       "sbappendaddr failed. send buffer size = %u, send_window = %u, error = %d\n",
+                                                       fd_cb->so->so_snd.sb_cc, fd_cb->send_window, error);
+                                       }
+                               } else {
+                                       if (!sbappendrecord(&fd_cb->so->so_snd, data)) {
+                                               FDLOG(LOG_ERR, fd_cb,
+                                                       "sbappendrecord failed. send buffer size = %u, send_window = %u, error = %d\n",
+                                                       fd_cb->so->so_snd.sb_cc, fd_cb->send_window, error);
+                                       }
+                               }
+                       } else {
+                               error = ENOBUFS;
+                       }
                }
        }
 
@@ -1408,14 +1535,15 @@ flow_divert_handle_connect_result(struct flow_divert_pcb *fd_cb, mbuf_t packet,
                        goto set_socket_state;
                }
 
-               if (local_address.ss_family != 0) {
+               if (local_address.ss_family == 0 && fd_cb->local_address == NULL) {
+                       error = EINVAL;
+                       goto set_socket_state;
+               }
+               if (local_address.ss_family != 0 && fd_cb->local_address == NULL) {
                        if (local_address.ss_len > sizeof(local_address)) {
                                local_address.ss_len = sizeof(local_address);
                        }
                        fd_cb->local_address = dup_sockaddr((struct sockaddr *)&local_address, 1);
-               } else {
-                       error = EINVAL;
-                       goto set_socket_state;
                }
 
                if (remote_address.ss_family != 0) {
@@ -1482,7 +1610,7 @@ set_socket_state:
                                flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
                                fd_cb->so->so_error = connect_error;
                        }
-                       soisdisconnected(fd_cb->so);
+                       flow_divert_disconnect_socket(fd_cb->so);
                } else {
                        soisconnected(fd_cb->so);
                }
@@ -1528,7 +1656,7 @@ flow_divert_handle_close(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offse
                
                how = flow_divert_tunnel_how_closed(fd_cb);
                if (how == SHUT_RDWR) {
-                       soisdisconnected(fd_cb->so);
+                       flow_divert_disconnect_socket(fd_cb->so);
                } else if (how == SHUT_RD) {
                        socantrcvmore(fd_cb->so);
                } else if (how == SHUT_WR) {
@@ -1540,49 +1668,119 @@ flow_divert_handle_close(struct flow_divert_pcb *fd_cb, mbuf_t packet, int offse
        FDUNLOCK(fd_cb);
 }
 
-static void
-flow_divert_handle_data(struct flow_divert_pcb *fd_cb, mbuf_t packet, size_t offset)
+static mbuf_t
+flow_divert_get_control_mbuf(struct flow_divert_pcb *fd_cb)   /* Returns a control mbuf built by sbcreatecontrol() from fd_cb->local_address (IP_RECVDSTADDR for IPv4, IPV6_PKTINFO for IPv6), or NULL when neither option is enabled on the socket's inpcb. */
 {
-       int             error           = 0;
-       mbuf_t  data            = NULL;
-       size_t  data_size;
+       struct inpcb *inp = sotoinpcb(fd_cb->so);
+       if (inp->inp_vflag & INP_IPV4 && inp->inp_flags & INP_RECVDSTADDR) {   /* IPv4 pcb with INP_RECVDSTADDR set */
+               struct sockaddr_in *sin = (struct sockaddr_in *)(void *)fd_cb->local_address;   /* NOTE(review): local_address is not NULL-checked here - confirm callers guarantee it is set */
 
-       data_size = (mbuf_pkthdr_len(packet) - offset);
+               return sbcreatecontrol((caddr_t) &sin->sin_addr, sizeof(struct in_addr), IP_RECVDSTADDR, IPPROTO_IP);
+       } else if (inp->inp_vflag & INP_IPV6 && (inp->inp_flags & IN6P_PKTINFO) != 0) {   /* IPv6 pcb with IN6P_PKTINFO set */
+               struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)(void *)fd_cb->local_address;   /* NOTE(review): same unchecked local_address as the IPv4 branch */
+               struct in6_pktinfo pi6;
 
-       FDLOG(LOG_DEBUG, fd_cb, "received %lu bytes of data", data_size);
-
-       error = mbuf_split(packet, offset, MBUF_DONTWAIT, &data);
-       if (error || data == NULL) {
-               FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
-               return;
+               bcopy(&sin6->sin6_addr, &pi6.ipi6_addr, sizeof (struct in6_addr));
+               pi6.ipi6_ifindex = 0;   /* interface index is always reported as 0 */
+               return sbcreatecontrol((caddr_t)&pi6, sizeof (struct in6_pktinfo), IPV6_PKTINFO, IPPROTO_IPV6);
        }
+       return (NULL);   /* neither option applies: no control data */
+}
 
+static void
+flow_divert_handle_data(struct flow_divert_pcb *fd_cb, mbuf_t packet, size_t offset)   /* Handles a received data packet: splits `packet` at `offset` and appends the resulting chain to the socket's receive buffer (stream or datagram), all under FDLOCK and the socket lock. */
+{
        FDLOCK(fd_cb);
        if (fd_cb->so != NULL) {
+               int             error           = 0;
+               mbuf_t  data            = NULL;
+               size_t  data_size;
+               struct sockaddr_storage remote_address;
+               boolean_t got_remote_sa = FALSE;   /* set only when a remote-address TLV is read and passes validation */
+
                socket_lock(fd_cb->so, 0);
-               if (flow_divert_check_no_cellular(fd_cb) || 
-                   flow_divert_check_no_expensive(fd_cb)) {
-                       flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);
-                       flow_divert_send_close(fd_cb, SHUT_RDWR);
-                       soisdisconnected(fd_cb->so);
-               } else if (!(fd_cb->so->so_state & SS_CANTRCVMORE)) {
-                       if (sbappendstream(&fd_cb->so->so_rcv, data)) {
-                               fd_cb->bytes_received += data_size;
-                               flow_divert_add_data_statistics(fd_cb, data_size, FALSE);
-                               fd_cb->sb_size = fd_cb->so->so_rcv.sb_cc;
-                               sorwakeup(fd_cb->so);
-                               data = NULL;
+
+               if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {   /* only datagram sockets look for a per-packet remote address */
+                       uint32_t val_size = 0;
+
+                       /* check if we got remote address with data */
+                       memset(&remote_address, 0, sizeof(remote_address));
+                       error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_REMOTE_ADDR, sizeof(remote_address), &remote_address, &val_size);
+                       if (error || val_size > sizeof(remote_address)) {
+                               FDLOG0(LOG_INFO, fd_cb, "No remote address provided");
+                               error = 0;   /* not fatal: fd_cb->remote_address is used for the append further down */
                        } else {
-                               FDLOG0(LOG_ERR, fd_cb, "received data, but appendstream failed");
+                               /* validate the address */
+                               if (flow_divert_is_sockaddr_valid((struct sockaddr *)&remote_address)) {
+                                       got_remote_sa = TRUE;
+                               }
+                               offset += (sizeof(uint8_t) + sizeof(uint32_t) + val_size);   /* skip 1-byte type + 4-byte length + value; NOTE(review): assumes the TLV sits exactly at `offset` - confirm */
+                       }
+               }
+
+               data_size = (mbuf_pkthdr_len(packet) - offset);   /* bytes remaining after the (possibly advanced) offset */
+
+               FDLOG(LOG_DEBUG, fd_cb, "received %lu bytes of data", data_size);
+
+               error = mbuf_split(packet, offset, MBUF_DONTWAIT, &data);   /* `data` receives the chain split off at `offset` */
+               if (error || data == NULL) {
+                       FDLOG(LOG_ERR, fd_cb, "mbuf_split failed: %d", error);
+               } else {
+                       if (flow_divert_check_no_cellular(fd_cb) || 
+                           flow_divert_check_no_expensive(fd_cb))
+                       {
+                               flow_divert_update_closed_state(fd_cb, SHUT_RDWR, TRUE);   /* presumably a no-cellular / no-expensive policy tripped: close both directions instead of delivering */
+                               flow_divert_send_close(fd_cb, SHUT_RDWR);
+                               flow_divert_disconnect_socket(fd_cb->so);
+                       } else if (!(fd_cb->so->so_state & SS_CANTRCVMORE)) {   /* deliver only while the socket can still receive */
+                               if (SOCK_TYPE(fd_cb->so) == SOCK_STREAM) {
+                                       if (sbappendstream(&fd_cb->so->so_rcv, data)) {
+                                               fd_cb->bytes_received += data_size;
+                                               flow_divert_add_data_statistics(fd_cb, data_size, FALSE);
+                                               fd_cb->sb_size = fd_cb->so->so_rcv.sb_cc;
+                                               sorwakeup(fd_cb->so);
+                                               data = NULL;   /* appended: clear so it is not freed below */
+                                       } else {
+                                               FDLOG0(LOG_ERR, fd_cb, "received data, but appendstream failed");
+                                       }
+                               } else if (SOCK_TYPE(fd_cb->so) == SOCK_DGRAM) {
+                                       struct sockaddr *append_sa;   /* NOTE(review): no initializer; if flow_divert_dup_addr() fails without setting it, it is still passed to sbappendaddr() below - confirm */
+                                       mbuf_t mctl;
+
+                                       if (got_remote_sa == TRUE) {   /* use the per-packet address from the TLV, otherwise the flow's remote_address */
+                                               error = flow_divert_dup_addr(fd_cb->so->so_proto->pr_domain->dom_family,
+                                                               (struct sockaddr *)&remote_address, &append_sa);
+                                       } else {
+                                               error = flow_divert_dup_addr(fd_cb->so->so_proto->pr_domain->dom_family,
+                                                               fd_cb->remote_address, &append_sa);
+                                       }
+                                       if (error) {
+                                               FDLOG0(LOG_ERR, fd_cb, "failed to dup the socket address.");   /* logged only; execution continues to the append */
+                                       }
+
+                                       mctl = flow_divert_get_control_mbuf(fd_cb);   /* may be NULL when no control data applies */
+                                       if (sbappendaddr(&fd_cb->so->so_rcv, append_sa, data, mctl, NULL)) {
+                                               fd_cb->bytes_received += data_size;
+                                               flow_divert_add_data_statistics(fd_cb, data_size, FALSE);
+                                               fd_cb->sb_size = fd_cb->so->so_rcv.sb_cc;
+                                               sorwakeup(fd_cb->so);
+                                               data = NULL;   /* appended: clear so it is not freed below */
+                                       } else {
+                                               FDLOG0(LOG_ERR, fd_cb, "received data, but sbappendaddr failed");
+                                       }
+                                       if (!error) {
+                                               FREE(append_sa, M_TEMP);   /* release the duplicated address; skipped when the dup failed */
+                                       }
+                               }
                        }
                }
                socket_unlock(fd_cb->so, 0);
-       }
-       FDUNLOCK(fd_cb);
 
-       if (data != NULL) {
-               mbuf_free(data);
+               if (data != NULL) {
+                       mbuf_freem(data);   /* still non-NULL here means it was not appended; NOTE(review): confirm sbappendaddr() does not already free `data` on failure */
+               }
        }
+       FDUNLOCK(fd_cb);
 }
 
 static void
@@ -1597,7 +1795,7 @@ flow_divert_handle_read_notification(struct flow_divert_pcb *fd_cb, mbuf_t packe
                return;
        }
 
-       FDLOG(LOG_DEBUG, fd_cb, "received a read notification for %u bytes", read_count);
+       FDLOG(LOG_DEBUG, fd_cb, "received a read notification for %u bytes", ntohl(read_count));
 
        FDLOCK(fd_cb);
        if (fd_cb->so != NULL) {
@@ -1613,7 +1811,7 @@ static void
 flow_divert_handle_group_init(struct flow_divert_group *group, mbuf_t packet, int offset)
 {
        int error = 0;
-       size_t key_size = 0;
+       uint32_t key_size = 0;
        int log_level;
 
        error = flow_divert_packet_get_tlv(packet, offset, FLOW_DIVERT_TLV_TOKEN_KEY, 0, NULL, &key_size);
@@ -1747,7 +1945,7 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
             cursor >= 0;
             cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
        {
-               size_t sid_size = 0;
+               uint32_t sid_size = 0;
                flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
                new_trie.bytes_count += sid_size;
                signing_id_count++;
@@ -1795,7 +1993,7 @@ flow_divert_handle_app_map_create(mbuf_t packet, int offset)
             cursor >= 0;
             cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
        {
-               size_t sid_size = 0;
+               uint32_t sid_size = 0;
                flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
                if (new_trie.bytes_free_next + sid_size <= new_trie.bytes_count) {
                        boolean_t is_dns;
@@ -1848,7 +2046,7 @@ flow_divert_handle_app_map_update(struct flow_divert_group *group, mbuf_t packet
             cursor >= 0;
             cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
        {
-               size_t sid_size = 0;
+               uint32_t sid_size = 0;
                flow_divert_packet_get_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
                if (sid_size > max_size) {
                        max_size = sid_size;
@@ -1865,7 +2063,7 @@ flow_divert_handle_app_map_update(struct flow_divert_group *group, mbuf_t packet
             cursor >= 0;
             cursor = flow_divert_packet_find_tlv(packet, cursor, FLOW_DIVERT_TLV_SIGNING_ID, &error, 1))
        {
-               size_t signing_id_len = 0;
+               uint32_t signing_id_len = 0;
                uint16_t node;
 
                flow_divert_packet_get_tlv(packet,
@@ -1904,6 +2102,12 @@ flow_divert_input(mbuf_t packet, struct flow_divert_group *group)
                goto done;
        }
 
+       if (mbuf_pkthdr_len(packet) > FD_CTL_RCVBUFF_SIZE) {
+               FDLOG(LOG_ERR, &nil_pcb, "got a bad packet, length (%lu) > %lu", mbuf_pkthdr_len(packet), FD_CTL_RCVBUFF_SIZE);
+               error = EINVAL;
+               goto done;
+       }
+
        error = mbuf_copydata(packet, 0, sizeof(hdr), &hdr);
        if (error) {
                FDLOG(LOG_ERR, &nil_pcb, "mbuf_copydata failed for the header: %d", error);
@@ -1963,7 +2167,7 @@ flow_divert_input(mbuf_t packet, struct flow_divert_group *group)
        FDRELEASE(fd_cb);
 
 done:
-       mbuf_free(packet);
+       mbuf_freem(packet);
        return error;
 }
 
@@ -2018,6 +2222,8 @@ flow_divert_detach(struct socket *so)
                /* Last-ditch effort to send any buffered data */
                flow_divert_send_buffered_data(fd_cb, TRUE);
 
+               flow_divert_update_closed_state(fd_cb, SHUT_RDWR, FALSE);
+               flow_divert_send_close_if_needed(fd_cb);
                /* Remove from the group */
                flow_divert_pcb_remove(fd_cb);
        }
@@ -2040,8 +2246,10 @@ flow_divert_close(struct socket *so)
 
        FDLOG0(LOG_INFO, fd_cb, "Closing");
 
-       soisdisconnecting(so);
-       sbflush(&so->so_rcv);
+       if (SOCK_TYPE(so) == SOCK_STREAM) {
+               soisdisconnecting(so);
+               sbflush(&so->so_rcv);
+       }
 
        flow_divert_send_buffered_data(fd_cb, TRUE);
        flow_divert_update_closed_state(fd_cb, SHUT_RDWR, FALSE);
@@ -2054,9 +2262,10 @@ flow_divert_close(struct socket *so)
 }
 
 static int
-flow_divert_disconnectx(struct socket *so, associd_t aid, connid_t cid __unused)
+flow_divert_disconnectx(struct socket *so, sae_associd_t aid,
+    sae_connid_t cid __unused)
 {
-       if (aid != ASSOCID_ANY && aid != ASSOCID_ALL) {
+       if (aid != SAE_ASSOCID_ANY && aid != SAE_ASSOCID_ALL) {
                return (EINVAL);
        }
 
@@ -2108,6 +2317,106 @@ flow_divert_rcvd(struct socket *so, int flags __unused)
        return 0;
 }
 
+/*
+ * Append the destination of a flow to a connect packet as two TLVs:
+ * FLOW_DIVERT_TLV_TARGET_ADDRESS carries the whole sockaddr (sa_len bytes)
+ * and FLOW_DIVERT_TLV_TARGET_PORT carries the port, passed through ntohs()
+ * and stored in an int.
+ *
+ * Returns 0 on success, otherwise the first error reported by
+ * flow_divert_packet_append_tlv().
+ */
+static int
+flow_divert_append_target_endpoint_tlv(mbuf_t connect_packet, struct sockaddr *toaddr)
+{
+       int status;
+       int target_port = 0;
+
+       status = flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_ADDRESS, toaddr->sa_len, toaddr);
+       if (status != 0) {
+               return status;
+       }
+
+       /* When INET6 is compiled out, a non-AF_INET address leaves the port at 0. */
+       if (toaddr->sa_family == AF_INET) {
+               target_port = ntohs((satosin(toaddr))->sin_port);
+       }
+#if INET6
+       else {
+               /* Every family other than AF_INET is read as a sockaddr_in6. */
+               target_port = ntohs((satosin6(toaddr))->sin6_port);
+       }
+#endif
+
+       return flow_divert_packet_append_tlv(connect_packet, FLOW_DIVERT_TLV_TARGET_PORT, sizeof(target_port), &target_port);
+}
+
+/*
+ * Return the socket address stored in the data area of `buffer`, or NULL.
+ *
+ * An address is returned only when `buffer` is non-NULL, its m_type is
+ * MT_SONAME, and the address passes flow_divert_is_sockaddr_valid().  The
+ * result is the mtod() pointer into the mbuf itself, not a copy.
+ */
+struct sockaddr *
+flow_divert_get_buffered_target_address(mbuf_t buffer)
+{
+       struct sockaddr *candidate;
+
+       if (buffer == NULL || buffer->m_type != MT_SONAME) {
+               return NULL;
+       }
+
+       candidate = mtod(buffer, struct sockaddr *);
+       if (candidate == NULL || !flow_divert_is_sockaddr_valid(candidate)) {
+               return NULL;
+       }
+       return candidate;
+}
+
+/*
+ * Check that a socket address has a supported family and the exact length
+ * for that family: AF_INET must be sizeof(struct sockaddr_in) and, when
+ * INET6 is compiled in, AF_INET6 must be sizeof(struct sockaddr_in6).
+ *
+ * Returns TRUE when both match, FALSE for a wrong length or any other
+ * family.  `addr` is dereferenced unconditionally and must not be NULL.
+ */
+static boolean_t
+flow_divert_is_sockaddr_valid(struct sockaddr *addr)
+{
+       if (addr->sa_family == AF_INET) {
+               return (addr->sa_len == sizeof(struct sockaddr_in)) ? TRUE : FALSE;
+       }
+#if INET6
+       if (addr->sa_family == AF_INET6) {
+               return (addr->sa_len == sizeof(struct sockaddr_in6)) ? TRUE : FALSE;
+       }
+#endif /* INET6 */
+       return FALSE;
+}
+
+/*
+ * Build a socket address describing the local endpoint of an inpcb and
+ * store a duplicate of it, made by dup_sockaddr(), in *local_socket.
+ *
+ * INP_IPV4 in inp_vflag is tested first and produces a sockaddr_in from
+ * inp_laddr; otherwise INP_IPV6 produces a sockaddr_in6 from in6p_laddr.
+ * inp_lport is copied as-is, with no byte-order conversion.
+ *
+ * Returns 0 on success, or ENOBUFS when dup_sockaddr() returns NULL.
+ *
+ * NOTE(review): when neither flag is set, the all-zero scratch address
+ * (family 0, length 0) is what gets handed to dup_sockaddr() -- confirm
+ * callers never reach this with an unset inp_vflag.
+ */
+static errno_t
+flow_divert_inp_to_sockaddr(const struct inpcb *inp, struct sockaddr **local_socket)
+{
+       union sockaddr_in_4_6 scratch;
+
+       bzero(&scratch, sizeof(scratch));
+
+       if (inp->inp_vflag & INP_IPV4) {
+               scratch.sin.sin_family = AF_INET;
+               scratch.sin.sin_len = sizeof(scratch.sin);
+               scratch.sin.sin_port = inp->inp_lport;
+               scratch.sin.sin_addr = inp->inp_laddr;
+       } else if (inp->inp_vflag & INP_IPV6) {
+               scratch.sin6.sin6_len = sizeof(scratch.sin6);
+               scratch.sin6.sin6_family = AF_INET6;
+               scratch.sin6.sin6_port = inp->inp_lport;
+               scratch.sin6.sin6_addr = inp->in6p_laddr;
+       }
+
+       *local_socket = dup_sockaddr((struct sockaddr *)&scratch, 1);
+       return (*local_socket == NULL) ? ENOBUFS : 0;
+}
+
+/*
+ * Report whether an inpcb has a local endpoint: a non-zero inp_lport
+ * together with a local address that is set in either its IPv4 view
+ * (inp_laddr != INADDR_ANY) or its IPv6 view (in6p_laddr not unspecified).
+ * inp_vflag is not consulted; both views are tested regardless of family.
+ */
+static boolean_t
+flow_divert_has_pcb_local_address(const struct inpcb *inp)
+{
+       if (inp->inp_lport == 0) {
+               return FALSE;
+       }
+       if (inp->inp_laddr.s_addr != INADDR_ANY) {
+               return TRUE;
+       }
+       return !IN6_IS_ADDR_UNSPECIFIED(&inp->in6p_laddr);
+}
+
 static errno_t
 flow_divert_dup_addr(sa_family_t family, struct sockaddr *addr,
                      struct sockaddr **dup)
@@ -2145,6 +2454,25 @@ flow_divert_dup_addr(sa_family_t family, struct sockaddr *addr,
        return error;
 }
 
+/*
+ * Mark a socket disconnected and, for datagram sockets only, also detach
+ * its inpcb: in6_pcbdetach() when the socket is in PF_INET6 (and INET6 is
+ * compiled in), in_pcbdetach() otherwise.  Sockets of any other type, and
+ * datagram sockets without an inpcb, only get soisdisconnected().
+ */
+static void
+flow_divert_disconnect_socket(struct socket *so)
+{
+       struct inpcb *inp;
+
+       soisdisconnected(so);
+
+       if (SOCK_TYPE(so) != SOCK_DGRAM) {
+               return;
+       }
+
+       inp = sotoinpcb(so);
+       if (inp == NULL) {
+               return;
+       }
+
+#if INET6
+       if (SOCK_CHECK_DOM(so, PF_INET6)) {
+               in6_pcbdetach(inp);
+               return;
+       }
+#endif /* INET6 */
+       in_pcbdetach(inp);
+}
+
 static errno_t
 flow_divert_getpeername(struct socket *so, struct sockaddr **sa)
 {
@@ -2244,6 +2572,20 @@ flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
                }
        }
 
+       if (fd_cb->local_address != NULL) {
+                error = EALREADY;
+                goto done;
+        } else {
+                if (flow_divert_has_pcb_local_address(inp)) {
+                        error = flow_divert_inp_to_sockaddr(inp, &fd_cb->local_address);
+                        if (error) {
+                                FDLOG0(LOG_ERR, fd_cb, "failed to get the local socket address.");
+                                goto done;
+                        }
+                }
+        }
+
+
        error = flow_divert_packet_init(fd_cb, FLOW_DIVERT_PKT_CONNECT, &connect_packet);
        if (error) {
                goto done;
@@ -2252,7 +2594,7 @@ flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
        error = EPERM;
 
        if (fd_cb->connect_token != NULL) {
-               size_t sid_size = 0;
+               uint32_t sid_size = 0;
                int find_error = flow_divert_packet_get_tlv(fd_cb->connect_token, 0, FLOW_DIVERT_TLV_SIGNING_ID, 0, NULL, &sid_size);
                if (find_error == 0 && sid_size > 0) {
                        MALLOC(signing_id, char *, sid_size + 1, M_TEMP, M_WAITOK | M_ZERO);
@@ -2274,7 +2616,9 @@ flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
                        if (src_proc != PROC_NULL) {
                                proc_lock(src_proc);
                                if (src_proc->p_csflags & CS_VALID) {
-                                       signing_id = (char *)cs_identity_get(src_proc);
+                    const char * cs_id;
+                    cs_id = cs_identity_get(src_proc);
+                    signing_id = __DECONST(char *, cs_id);
                                } else {
                                        FDLOG0(LOG_WARNING, fd_cb, "Signature is invalid");
                                }
@@ -2288,7 +2632,7 @@ flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
                if (signing_id != NULL) {
                        uint16_t result = NULL_TRIE_IDX;
                        lck_rw_lock_shared(&g_flow_divert_group_lck);
-                       result = flow_divert_trie_search(&g_signing_id_trie, (const uint8_t *)signing_id);
+                       result = flow_divert_trie_search(&g_signing_id_trie, (uint8_t *)signing_id);
                        lck_rw_done(&g_flow_divert_group_lck);
                        if (result != NULL_TRIE_IDX) {
                                error = 0;
@@ -2350,7 +2694,7 @@ flow_divert_connect_out(struct socket *so, struct sockaddr *to, proc_t p)
 
 done:
        if (error && connect_packet != NULL) {
-               mbuf_free(connect_packet);
+               mbuf_freem(connect_packet);
        }
        return error;
 }
@@ -2358,8 +2702,8 @@ done:
 static int
 flow_divert_connectx_out_common(struct socket *so, int af,
     struct sockaddr_list **src_sl, struct sockaddr_list **dst_sl,
-    struct proc *p, uint32_t ifscope __unused, associd_t aid __unused,
-    connid_t *pcid, uint32_t flags __unused, void *arg __unused,
+    struct proc *p, uint32_t ifscope __unused, sae_associd_t aid __unused,
+    sae_connid_t *pcid, uint32_t flags __unused, void *arg __unused,
     uint32_t arglen __unused)
 {
        struct sockaddr_entry *src_se = NULL, *dst_se = NULL;
@@ -2395,9 +2739,10 @@ flow_divert_connectx_out_common(struct socket *so, int af,
 static int
 flow_divert_connectx_out(struct socket *so, struct sockaddr_list **src_sl,
     struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
-    associd_t aid, connid_t *pcid, uint32_t flags, void *arg,
-    uint32_t arglen)
+    sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
+    uint32_t arglen, struct uio *uio, user_ssize_t *bytes_written)
 {
+#pragma unused(uio, bytes_written)
        return (flow_divert_connectx_out_common(so, AF_INET, src_sl, dst_sl,
            p, ifscope, aid, pcid, flags, arg, arglen));
 }
@@ -2406,16 +2751,17 @@ flow_divert_connectx_out(struct socket *so, struct sockaddr_list **src_sl,
 static int
 flow_divert_connectx6_out(struct socket *so, struct sockaddr_list **src_sl,
     struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
-    associd_t aid, connid_t *pcid, uint32_t flags, void *arg,
-    uint32_t arglen)
+    sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
+    uint32_t arglen, struct uio *uio, user_ssize_t *bytes_written)
 {
+#pragma unused(uio, bytes_written)
        return (flow_divert_connectx_out_common(so, AF_INET6, src_sl, dst_sl,
            p, ifscope, aid, pcid, flags, arg, arglen));
 }
 #endif /* INET6 */
 
 static int
-flow_divert_getconninfo(struct socket *so, connid_t cid, uint32_t *flags,
+flow_divert_getconninfo(struct socket *so, sae_connid_t cid, uint32_t *flags,
                         uint32_t *ifindex, int32_t *soerror, user_addr_t src, socklen_t *src_len,
                         user_addr_t dst, socklen_t *dst_len, uint32_t *aux_type,
                         user_addr_t aux_data __unused, uint32_t *aux_len)
@@ -2432,7 +2778,7 @@ flow_divert_getconninfo(struct socket *so, connid_t cid, uint32_t *flags,
                goto out;
        }
 
-       if (cid != CONNID_ANY && cid != CONNID_ALL && cid != 1) {
+       if (cid != SAE_CONNID_ANY && cid != SAE_CONNID_ALL && cid != 1) {
                error = EINVAL;
                goto out;
        }
@@ -2605,7 +2951,7 @@ flow_divert_data_out(struct socket *so, int flags, mbuf_t data, struct sockaddr
        FDLOG(LOG_DEBUG, fd_cb, "app wrote %lu bytes", mbuf_pkthdr_len(data));
 
        fd_cb->bytes_written_by_app += mbuf_pkthdr_len(data);
-       error = flow_divert_send_app_data(fd_cb, data);
+       error = flow_divert_send_app_data(fd_cb, data, to);
        if (error) {
                goto done;
        }
@@ -2618,7 +2964,7 @@ flow_divert_data_out(struct socket *so, int flags, mbuf_t data, struct sockaddr
 
 done:
        if (data) {
-               mbuf_free(data);
+               mbuf_freem(data);
        }
        if (control) {
                mbuf_free(control);
@@ -2640,6 +2986,20 @@ flow_divert_set_protosw(struct socket *so)
 #endif /* INET6 */
 }
 
+/*
+ * Switch a socket over to the flow-divert UDP protocol switch: set
+ * SOF_FLOW_DIVERT in so_flags and point so_proto at the IPv4 table for
+ * PF_INET sockets or, when INET6 is compiled in, at the IPv6 table for
+ * every other domain.  Without INET6 a non-PF_INET socket keeps its
+ * current so_proto.
+ */
+static void
+flow_divert_set_udp_protosw(struct socket *so)
+{
+       so->so_flags |= SOF_FLOW_DIVERT;
+
+       if (SOCK_DOM(so) == PF_INET) {
+               so->so_proto = &g_flow_divert_in_udp_protosw;
+               return;
+       }
+#if INET6
+       so->so_proto = (struct protosw *)&g_flow_divert_in6_udp_protosw;
+#endif  /* INET6 */
+}
+
 static errno_t
 flow_divert_attach(struct socket *so, uint32_t flow_id, uint32_t ctl_unit)
 {
@@ -2679,10 +3039,14 @@ flow_divert_attach(struct socket *so, uint32_t flow_id, uint32_t ctl_unit)
        VERIFY(inp != NULL);
 
        socket_lock(old_so, 0);
-       soisdisconnected(old_so);
+       flow_divert_disconnect_socket(old_so);
        old_so->so_flags &= ~SOF_FLOW_DIVERT;
        old_so->so_fd_pcb = NULL;
-       old_so->so_proto = pffindproto(SOCK_DOM(old_so), IPPROTO_TCP, SOCK_STREAM);
+       if (SOCK_TYPE(old_so) == SOCK_STREAM) {
+               old_so->so_proto = pffindproto(SOCK_DOM(old_so), IPPROTO_TCP, SOCK_STREAM);
+       } else if (SOCK_TYPE(old_so) == SOCK_DGRAM) {
+               old_so->so_proto = pffindproto(SOCK_DOM(old_so), IPPROTO_UDP, SOCK_DGRAM);
+       }
        fd_cb->so = NULL;
        /* Save the output interface */
        ifp = inp->inp_last_outifp;
@@ -2720,6 +3084,44 @@ done:
        return error;
 }
 
+/*
+ * Send data on a socket that may not have a flow-divert PCB yet.
+ *
+ * If so->so_fd_pcb is NULL, the control unit is looked up with
+ * necp_socket_get_flow_divert_control_unit() and a PCB is created with
+ * flow_divert_pcb_init(); the send is then handed to
+ * flow_divert_data_out().
+ *
+ * Returns:
+ *   EINVAL    the socket has no inpcb
+ *   ENETDOWN  no PCB exists and the control unit is 0
+ *   the error from flow_divert_pcb_init(), if PCB creation fails
+ *   otherwise whatever flow_divert_data_out() returns
+ *
+ * Every failure detected in this function frees `data` (whole chain) and
+ * `control` before returning, so the caller must not touch them afterwards.
+ */
+errno_t
+flow_divert_implicit_data_out(struct socket *so, int flags, mbuf_t data, struct sockaddr *to, mbuf_t control, struct proc *p)
+{
+       struct flow_divert_pcb  *fd_cb  = so->so_fd_pcb;
+       struct inpcb *inp;
+       int error = 0;
+
+       inp = sotoinpcb(so);
+       if (inp == NULL) {
+               /*
+                * Go through the common cleanup so the mbufs are released
+                * here too, like on every other failure path; returning
+                * directly would leak them.
+                */
+               error = EINVAL;
+               goto done;
+       }
+
+       if (fd_cb == NULL) {
+               uint32_t fd_ctl_unit = necp_socket_get_flow_divert_control_unit(inp);
+               if (fd_ctl_unit > 0) {
+                       error = flow_divert_pcb_init(so, fd_ctl_unit);
+                       /* Re-read: a successful init is expected to set so_fd_pcb. */
+                       fd_cb  = so->so_fd_pcb;
+                       /*
+                        * NOTE(review): if init reports success but leaves
+                        * so_fd_pcb NULL, the data is dropped and 0 is
+                        * returned -- confirm that cannot happen.
+                        */
+                       if (error != 0 || fd_cb == NULL) {
+                               goto done;
+                       }
+               } else {
+                       error = ENETDOWN;
+                       goto done;
+               }
+       }
+       return flow_divert_data_out(so, flags, data, to, control, p);
+
+done:
+       if (data) {
+               mbuf_freem(data);
+       }
+       if (control) {
+               mbuf_free(control);
+       }
+
+       return error;
+}
+
 errno_t
 flow_divert_pcb_init(struct socket *so, uint32_t ctl_unit)
 {
@@ -2737,11 +3139,14 @@ flow_divert_pcb_init(struct socket *so, uint32_t ctl_unit)
                        FDLOG(LOG_ERR, fd_cb, "pcb insert failed: %d", error);
                        FDRELEASE(fd_cb);
                } else {
-                       fd_cb->log_level = LOG_NOTICE;
                        fd_cb->control_group_unit = ctl_unit;
                        so->so_fd_pcb = fd_cb;
 
-                       flow_divert_set_protosw(so);
+                       if (SOCK_TYPE(so) == SOCK_STREAM) {
+                               flow_divert_set_protosw(so);
+                       } else if (SOCK_TYPE(so) == SOCK_DGRAM) {
+                               flow_divert_set_udp_protosw(so);
+                       }
 
                        FDLOG0(LOG_INFO, fd_cb, "Created");
                }
@@ -2772,8 +3177,8 @@ flow_divert_token_set(struct socket *so, struct sockopt *sopt)
                goto done;
        }
 
-       if (SOCK_TYPE(so) != SOCK_STREAM ||
-           SOCK_PROTO(so) != IPPROTO_TCP ||
+       if ((SOCK_TYPE(so) != SOCK_STREAM && SOCK_TYPE(so) != SOCK_DGRAM) ||
+           (SOCK_PROTO(so) != IPPROTO_TCP && SOCK_PROTO(so) != IPPROTO_UDP) ||
            (SOCK_DOM(so) != PF_INET
 #if INET6
             && SOCK_DOM(so) != PF_INET6
@@ -2783,10 +3188,12 @@ flow_divert_token_set(struct socket *so, struct sockopt *sopt)
                error = EINVAL;
                goto done;
        } else {
-               struct tcpcb *tp = sototcpcb(so);
-               if (tp == NULL || tp->t_state != TCPS_CLOSED) {
-                       error = EINVAL;
-                       goto done;
+               if (SOCK_TYPE(so) == SOCK_STREAM && SOCK_PROTO(so) == IPPROTO_TCP) {
+                       struct tcpcb *tp = sototcpcb(so);
+                       if (tp == NULL || tp->t_state != TCPS_CLOSED) {
+                               error = EINVAL;
+                               goto done;
+                       }
                }
        }
 
@@ -2957,7 +3364,7 @@ done:
 static errno_t
 flow_divert_kctl_connect(kern_ctl_ref kctlref __unused, struct sockaddr_ctl *sac, void **unitinfo)
 {
-       struct flow_divert_group        *new_group;
+       struct flow_divert_group        *new_group      = NULL;
        int                             error           = 0;
 
        if (sac->sc_unit >= GROUP_COUNT_MAX) {
@@ -3193,6 +3600,39 @@ flow_divert_init(void)
        g_flow_divert_in_protosw.pr_filter_head.tqh_last =
            (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
 
+       /*
+        * UDP: clone the real IPv4 UDP protosw and its usrreqs table, then
+        * override the entry points that flow divert handles itself.
+        */
+       g_udp_protosw = pffindproto(AF_INET, IPPROTO_UDP, SOCK_DGRAM);
+       VERIFY(g_udp_protosw != NULL);
+
+       memcpy(&g_flow_divert_in_udp_protosw, g_udp_protosw, sizeof(g_flow_divert_in_udp_protosw));
+       memcpy(&g_flow_divert_in_udp_usrreqs, g_udp_protosw->pr_usrreqs, sizeof(g_flow_divert_in_udp_usrreqs));
+
+       g_flow_divert_in_udp_usrreqs.pru_connect = flow_divert_connect_out;
+       g_flow_divert_in_udp_usrreqs.pru_connectx = flow_divert_connectx_out;
+       g_flow_divert_in_udp_usrreqs.pru_control = flow_divert_in_control;
+       g_flow_divert_in_udp_usrreqs.pru_disconnect = flow_divert_close;
+       g_flow_divert_in_udp_usrreqs.pru_disconnectx = flow_divert_disconnectx;
+       g_flow_divert_in_udp_usrreqs.pru_peeraddr = flow_divert_getpeername;
+       g_flow_divert_in_udp_usrreqs.pru_rcvd = flow_divert_rcvd;
+       g_flow_divert_in_udp_usrreqs.pru_send = flow_divert_data_out;
+       g_flow_divert_in_udp_usrreqs.pru_shutdown = flow_divert_shutdown;
+       g_flow_divert_in_udp_usrreqs.pru_sockaddr = flow_divert_getsockaddr;
+       g_flow_divert_in_udp_usrreqs.pru_sosend_list = pru_sosend_list_notsupp;
+       g_flow_divert_in_udp_usrreqs.pru_soreceive_list = pru_soreceive_list_notsupp;
+
+       /*
+        * Point the UDP protosw at the UDP usrreqs table built just above.
+        * It previously pointed at g_flow_divert_in_usrreqs, which left
+        * the table above unused; the UDP6 setup uses its own table.
+        */
+       g_flow_divert_in_udp_protosw.pr_usrreqs = &g_flow_divert_in_udp_usrreqs;
+       g_flow_divert_in_udp_protosw.pr_ctloutput = flow_divert_ctloutput;
+
+       /*
+       * Socket filters shouldn't attach/detach to/from this protosw
+       * since pr_protosw is to be used instead, which points to the
+       * real protocol; if they do, it is a bug and we should panic.
+       */
+       g_flow_divert_in_udp_protosw.pr_filter_head.tqh_first =
+           (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
+       g_flow_divert_in_udp_protosw.pr_filter_head.tqh_last =
+           (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
+
 #if INET6
        g_tcp6_protosw = (struct ip6protosw *)pffindproto(AF_INET6, IPPROTO_TCP, SOCK_STREAM);
 
@@ -3223,6 +3663,39 @@ flow_divert_init(void)
            (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
        g_flow_divert_in6_protosw.pr_filter_head.tqh_last =
            (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
+
+       /* UDP6 */
+       g_udp6_protosw = (struct ip6protosw *)pffindproto(AF_INET6, IPPROTO_UDP, SOCK_DGRAM);
+
+       VERIFY(g_udp6_protosw != NULL);
+
+       memcpy(&g_flow_divert_in6_udp_protosw, g_udp6_protosw, sizeof(g_flow_divert_in6_udp_protosw));
+       memcpy(&g_flow_divert_in6_udp_usrreqs, g_udp6_protosw->pr_usrreqs, sizeof(g_flow_divert_in6_udp_usrreqs));
+
+       g_flow_divert_in6_udp_usrreqs.pru_connect = flow_divert_connect_out;
+       g_flow_divert_in6_udp_usrreqs.pru_connectx = flow_divert_connectx6_out;
+       g_flow_divert_in6_udp_usrreqs.pru_control = flow_divert_in6_control;
+       g_flow_divert_in6_udp_usrreqs.pru_disconnect = flow_divert_close;
+       g_flow_divert_in6_udp_usrreqs.pru_disconnectx = flow_divert_disconnectx;
+       g_flow_divert_in6_udp_usrreqs.pru_peeraddr = flow_divert_getpeername;
+       g_flow_divert_in6_udp_usrreqs.pru_rcvd = flow_divert_rcvd;
+       g_flow_divert_in6_udp_usrreqs.pru_send = flow_divert_data_out;
+       g_flow_divert_in6_udp_usrreqs.pru_shutdown = flow_divert_shutdown;
+       g_flow_divert_in6_udp_usrreqs.pru_sockaddr = flow_divert_getsockaddr;
+       g_flow_divert_in6_udp_usrreqs.pru_sosend_list = pru_sosend_list_notsupp;
+       g_flow_divert_in6_udp_usrreqs.pru_soreceive_list = pru_soreceive_list_notsupp;
+
+       g_flow_divert_in6_udp_protosw.pr_usrreqs = &g_flow_divert_in6_udp_usrreqs;
+       g_flow_divert_in6_udp_protosw.pr_ctloutput = flow_divert_ctloutput;
+       /*
+       * Socket filters shouldn't attach/detach to/from this protosw
+       * since pr_protosw is to be used instead, which points to the
+       * real protocol; if they do, it is a bug and we should panic.
+       */
+       g_flow_divert_in6_udp_protosw.pr_filter_head.tqh_first =
+           (struct socket_filter *)(uintptr_t)0xdeadbeefdeadbeef;
+       g_flow_divert_in6_udp_protosw.pr_filter_head.tqh_last =
+           (struct socket_filter **)(uintptr_t)0xdeadbeefdeadbeef;
 #endif /* INET6 */
 
        flow_divert_grp_attr = lck_grp_attr_alloc_init();