]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/kern/kpi_socket.c
xnu-2050.48.11.tar.gz
[apple/xnu.git] / bsd / kern / kpi_socket.c
index 357e1f40436a4dda94e33de060be4ba3c2266085..3de525cbe2bae4b4f7bff7eae809cd36dce8dd5f 100644 (file)
@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright (c) 2003-2004 Apple Computer, Inc. All rights reserved.
+ * Copyright (c) 2003-2011 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  * 
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  * 
@@ -27,6 +27,7 @@
  */
 
 #define __KPI__
  */
 
 #define __KPI__
+#include <sys/systm.h>
 #include <sys/kernel.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/kernel.h>
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <sys/filio.h>
 #include <sys/uio_internal.h>
 #include <kern/lock.h>
 #include <sys/filio.h>
 #include <sys/uio_internal.h>
 #include <kern/lock.h>
+#include <netinet/in.h>
+#include <libkern/OSAtomic.h>
 
 
-extern void    *memcpy(void *, const void *, size_t);
 extern int soclose_locked(struct socket *so);
 extern int soclose_locked(struct socket *so);
+extern void soclose_wait_locked(struct socket *so);
+extern int so_isdstlocal(struct socket *so);
 
 errno_t sock_send_internal(
        socket_t                        sock,
 
 errno_t sock_send_internal(
        socket_t                        sock,
@@ -53,6 +57,7 @@ errno_t sock_send_internal(
        int                                     flags,
        size_t                          *sentlen);
 
        int                                     flags,
        size_t                          *sentlen);
 
+typedef        void    (*so_upcall)(struct socket *, caddr_t , int );
 
 
 errno_t
 
 
 errno_t
@@ -101,7 +106,7 @@ sock_accept(
                        sock->so_error = ECONNABORTED;
                        break;
                }
                        sock->so_error = ECONNABORTED;
                        break;
                }
-               error = msleep((caddr_t)&sock->so_timeo, mutex_held, PSOCK | PCATCH, "sock_accept", 0);
+               error = msleep((caddr_t)&sock->so_timeo, mutex_held, PSOCK | PCATCH, "sock_accept", NULL);
                if (error) {
                        socket_unlock(sock, 1);
                        return (error);
                if (error) {
                        socket_unlock(sock, 1);
                        return (error);
@@ -117,7 +122,26 @@ sock_accept(
        new_so = TAILQ_FIRST(&sock->so_comp);
        TAILQ_REMOVE(&sock->so_comp, new_so, so_list);
        sock->so_qlen--;
        new_so = TAILQ_FIRST(&sock->so_comp);
        TAILQ_REMOVE(&sock->so_comp, new_so, so_list);
        sock->so_qlen--;
-       socket_unlock(sock, 1); /* release the head */
+
+       /*
+        * Pass the pre-accepted socket to any interested socket filter(s).
+        * Upon failure, the socket would have been closed by the callee.
+        */
+       if (new_so->so_filt != NULL) {
+               /*
+                * Temporarily drop the listening socket's lock before we
+                * hand off control over to the socket filter(s), but keep
+                * a reference so that it won't go away.  We'll grab it
+                * again once we're done with the filter(s).
+                */
+               socket_unlock(sock, 0);
+               if ((error = soacceptfilter(new_so)) != 0) {
+                       /* Drop reference on listening socket */
+                       sodereference(sock);
+                       return (error);
+               }
+               socket_lock(sock, 0);
+       }
 
        if (dosocklock) {
                lck_mtx_assert(new_so->so_proto->pr_getlock(new_so, 0),
 
        if (dosocklock) {
                lck_mtx_assert(new_so->so_proto->pr_getlock(new_so, 0),
@@ -127,12 +151,17 @@ sock_accept(
        
        new_so->so_state &= ~SS_COMP;
        new_so->so_head = NULL;
        
        new_so->so_state &= ~SS_COMP;
        new_so->so_head = NULL;
-       soacceptlock(new_so, &sa, 0);
+       (void) soacceptlock(new_so, &sa, 0);
        
        
+       socket_unlock(sock, 1); /* release the head */
+
        if (callback) {
        if (callback) {
-               new_so->so_upcall = callback;
+               new_so->so_upcall = (so_upcall) callback;
                new_so->so_upcallarg = cookie;
                new_so->so_rcv.sb_flags |= SB_UPCALL;
                new_so->so_upcallarg = cookie;
                new_so->so_rcv.sb_flags |= SB_UPCALL;
+#if CONFIG_SOWUPCALL
+               new_so->so_snd.sb_flags |= SB_UPCALL;
+#endif
        }
        
        if (sa && from)
        }
        
        if (sa && from)
@@ -141,6 +170,15 @@ sock_accept(
                memcpy(from, sa, fromlen);
        }
        if (sa) FREE(sa, M_SONAME);
                memcpy(from, sa, fromlen);
        }
        if (sa) FREE(sa, M_SONAME);
+
+       /*
+        * If the socket has been marked as inactive by sosetdefunct(),
+        * disallow further operations on it.
+        */
+       if (new_so->so_flags & SOF_DEFUNCT) {
+               (void) sodefunct(current_proc(), new_so,
+                   SHUTDOWN_SOCKET_LEVEL_DISCONNECT_INTERNAL);
+       }
        *new_sock = new_so;
        if (dosocklock) 
                socket_unlock(new_so, 1);
        *new_sock = new_so;
        if (dosocklock) 
                socket_unlock(new_so, 1);
@@ -152,9 +190,30 @@ sock_bind(
        socket_t                                sock,
        const struct sockaddr   *to)
 {
        socket_t                                sock,
        const struct sockaddr   *to)
 {
-       if (sock == NULL || to == NULL) return EINVAL;
+       int     error = 0;
+       struct sockaddr *sa = NULL;
+       struct sockaddr_storage ss;
+       boolean_t want_free = TRUE;
+
+       if (sock == NULL || to == NULL) 
+               return EINVAL;
        
        
-       return sobind(sock, (struct sockaddr*)to);
+       if (to->sa_len > sizeof(ss)) {
+               MALLOC(sa, struct sockaddr *, to->sa_len, M_SONAME, M_WAITOK);
+               if (sa == NULL)
+                       return ENOBUFS;
+       } else {
+               sa = (struct sockaddr *)&ss;
+               want_free = FALSE;
+       }
+       memcpy(sa, to, to->sa_len);
+
+       error = sobind(sock, sa);
+       
+       if (sa != NULL && want_free == TRUE)
+               FREE(sa, M_SONAME);     
+
+       return error;
 }
 
 errno_t
 }
 
 errno_t
@@ -165,23 +224,37 @@ sock_connect(
 {
        int     error = 0;
        lck_mtx_t *mutex_held;
 {
        int     error = 0;
        lck_mtx_t *mutex_held;
+       struct sockaddr *sa = NULL;
+       struct sockaddr_storage ss;
+       boolean_t want_free = TRUE;
        
        if (sock == NULL || to == NULL) return EINVAL;
        
        if (sock == NULL || to == NULL) return EINVAL;
+       
+       if (to->sa_len > sizeof(ss)) {
+               MALLOC(sa, struct sockaddr *, to->sa_len, M_SONAME,
+                       (flags & MSG_DONTWAIT) ? M_NOWAIT : M_WAITOK);
+               if (sa == NULL)
+                       return ENOBUFS;
+       } else {
+               sa = (struct sockaddr *)&ss;
+               want_free = FALSE;
+       }
+       memcpy(sa, to, to->sa_len);
 
        socket_lock(sock, 1);
 
        if ((sock->so_state & SS_ISCONNECTING) &&
                ((sock->so_state & SS_NBIO) != 0 ||
                 (flags & MSG_DONTWAIT) != 0)) {
 
        socket_lock(sock, 1);
 
        if ((sock->so_state & SS_ISCONNECTING) &&
                ((sock->so_state & SS_NBIO) != 0 ||
                 (flags & MSG_DONTWAIT) != 0)) {
-               socket_unlock(sock, 1);
-               return EALREADY;
+               error = EALREADY;
+               goto out;
        }
        }
-       error = soconnectlock(sock, (struct sockaddr*)to, 0);
+       error = soconnectlock(sock, sa, 0);
        if (!error) {
                if ((sock->so_state & SS_ISCONNECTING) &&
                        ((sock->so_state & SS_NBIO) != 0 || (flags & MSG_DONTWAIT) != 0)) {
        if (!error) {
                if ((sock->so_state & SS_ISCONNECTING) &&
                        ((sock->so_state & SS_NBIO) != 0 || (flags & MSG_DONTWAIT) != 0)) {
-                       socket_unlock(sock, 1);
-                       return EINPROGRESS;
+                       error = EINPROGRESS;
+                       goto out;
                }
                
                if (sock->so_proto->pr_getlock != NULL)  
                }
                
                if (sock->so_proto->pr_getlock != NULL)  
@@ -191,7 +264,7 @@ sock_connect(
 
                while ((sock->so_state & SS_ISCONNECTING) && sock->so_error == 0) {
                        error = msleep((caddr_t)&sock->so_timeo, mutex_held, PSOCK | PCATCH,
 
                while ((sock->so_state & SS_ISCONNECTING) && sock->so_error == 0) {
                        error = msleep((caddr_t)&sock->so_timeo, mutex_held, PSOCK | PCATCH,
-                               "sock_connect", 0);
+                               "sock_connect", NULL);
                        if (error)
                                break;
                }
                        if (error)
                                break;
                }
@@ -204,7 +277,12 @@ sock_connect(
        else {
                sock->so_state &= ~SS_ISCONNECTING;
        }
        else {
                sock->so_state &= ~SS_ISCONNECTING;
        }
+out:
        socket_unlock(sock, 1);
        socket_unlock(sock, 1);
+
+       if (sa != NULL && want_free == TRUE)
+               FREE(sa, M_SONAME);
+               
        return error;
 }
 
        return error;
 }
 
@@ -293,51 +371,92 @@ sock_nointerrupt(
 }
 
 errno_t
 }
 
 errno_t
-sock_getpeername(
-       socket_t                sock,
-       struct sockaddr *peername,
-       int                             peernamelen)
+sock_getpeername(socket_t sock, struct sockaddr        *peername, int peernamelen)
 {
 {
-       int                             error = 0;
+       int error;
        struct sockaddr *sa = NULL;
        struct sockaddr *sa = NULL;
-       
-       if (sock == NULL || peername == NULL || peernamelen < 0) return EINVAL;
+
+       if (sock == NULL || peername == NULL || peernamelen < 0)
+               return (EINVAL);
+
        socket_lock(sock, 1);
        socket_lock(sock, 1);
-       if ((sock->so_state & (SS_ISCONNECTED|SS_ISCONFIRMING)) == 0) {
+       if (!(sock->so_state & (SS_ISCONNECTED|SS_ISCONFIRMING))) {
                socket_unlock(sock, 1);
                socket_unlock(sock, 1);
-               return ENOTCONN;
+               return (ENOTCONN);
        }
        }
-       error = sock->so_proto->pr_usrreqs->pru_peeraddr(sock, &sa);
-       if (!error)
-       {
-               if (peernamelen > sa->sa_len) peernamelen = sa->sa_len;
+       error = sogetaddr_locked(sock, &sa, 1);
+       socket_unlock(sock, 1);
+       if (error == 0) {
+               if (peernamelen > sa->sa_len)
+                       peernamelen = sa->sa_len;
                memcpy(peername, sa, peernamelen);
                memcpy(peername, sa, peernamelen);
+               FREE(sa, M_SONAME);
        }
        }
-       if (sa) FREE(sa, M_SONAME);
-       socket_unlock(sock, 1);
-       return error;
+       return (error);
 }
 
 errno_t
 }
 
 errno_t
-sock_getsockname(
-       socket_t                sock,
-       struct sockaddr *sockname,
-       int                             socknamelen)
+sock_getsockname(socket_t sock, struct sockaddr        *sockname, int socknamelen)
 {
 {
-       int                             error = 0;
+       int error;
        struct sockaddr *sa = NULL;
        struct sockaddr *sa = NULL;
-       
-       if (sock == NULL || sockname == NULL || socknamelen < 0) return EINVAL;
+
+       if (sock == NULL || sockname == NULL || socknamelen < 0)
+               return (EINVAL);
+
        socket_lock(sock, 1);
        socket_lock(sock, 1);
-       error = sock->so_proto->pr_usrreqs->pru_sockaddr(sock, &sa);
-       if (!error)
-       {
-               if (socknamelen > sa->sa_len) socknamelen = sa->sa_len;
+       error = sogetaddr_locked(sock, &sa, 0);
+       socket_unlock(sock, 1);
+       if (error == 0) {
+               if (socknamelen > sa->sa_len)
+                       socknamelen = sa->sa_len;
                memcpy(sockname, sa, socknamelen);
                memcpy(sockname, sa, socknamelen);
+               FREE(sa, M_SONAME);
        }
        }
-       if (sa) FREE(sa, M_SONAME);
+       return (error);
+}
+
+__private_extern__ int
+sogetaddr_locked(struct socket *so, struct sockaddr **psa, int peer)
+{
+       int error;
+
+       if (so == NULL || psa == NULL)
+               return (EINVAL);
+
+       *psa = NULL;
+       error = peer ? so->so_proto->pr_usrreqs->pru_peeraddr(so, psa) :
+           so->so_proto->pr_usrreqs->pru_sockaddr(so, psa);
+
+       if (error == 0 && *psa == NULL) {
+               error = ENOMEM;
+       } else if (error != 0 && *psa != NULL) {
+               FREE(*psa, M_SONAME);
+               *psa = NULL;
+       }
+       return (error);
+}
+
+errno_t
+sock_getaddr(socket_t sock, struct sockaddr **psa, int peer)
+{
+       int error;
+
+       if (sock == NULL || psa == NULL)
+               return (EINVAL);
+
+       socket_lock(sock, 1);
+       error = sogetaddr_locked(sock, psa, peer);
        socket_unlock(sock, 1);
        socket_unlock(sock, 1);
-       return error;
+       
+       return (error);
+}
+
+void
+sock_freeaddr(struct sockaddr *sa)
+{
+       if (sa != NULL)
+               FREE(sa, M_SONAME);
 }
 
 errno_t
 }
 
 errno_t
@@ -357,7 +476,7 @@ sock_getsockopt(
        sopt.sopt_name = optname;
        sopt.sopt_val = CAST_USER_ADDR_T(optval); 
        sopt.sopt_valsize = *optlen;
        sopt.sopt_name = optname;
        sopt.sopt_val = CAST_USER_ADDR_T(optval); 
        sopt.sopt_valsize = *optlen;
-       sopt.sopt_p = NULL;
+       sopt.sopt_p = kernproc;
        error = sogetopt(sock, &sopt); /* will lock socket */
        if (error == 0) *optlen = sopt.sopt_valsize;
        return error;
        error = sogetopt(sock, &sopt); /* will lock socket */
        if (error == 0) *optlen = sopt.sopt_valsize;
        return error;
@@ -369,7 +488,7 @@ sock_ioctl(
        unsigned long request,
        void *argp)
 {
        unsigned long request,
        void *argp)
 {
-       return soioctl(sock, request, argp, NULL); /* will lock socket */
+       return soioctl(sock, request, argp, kernproc); /* will lock socket */
 }
 
 errno_t
 }
 
 errno_t
@@ -388,10 +507,156 @@ sock_setsockopt(
        sopt.sopt_name = optname;
        sopt.sopt_val = CAST_USER_ADDR_T(optval);
        sopt.sopt_valsize = optlen;
        sopt.sopt_name = optname;
        sopt.sopt_val = CAST_USER_ADDR_T(optval);
        sopt.sopt_valsize = optlen;
-       sopt.sopt_p = NULL;
+       sopt.sopt_p = kernproc;
        return sosetopt(sock, &sopt); /* will lock socket */
 }
 
        return sosetopt(sock, &sopt); /* will lock socket */
 }
 
+/*
+ * This follows the recommended mappings between DSCP code points and WMM access classes
+ */
+static u_int32_t so_tc_from_dscp(u_int8_t dscp);
+static u_int32_t
+so_tc_from_dscp(u_int8_t dscp)
+{
+       u_int32_t tc;
+
+       if (dscp >= 0x30 && dscp <= 0x3f)
+               tc = SO_TC_VO;
+       else if (dscp >= 0x20 && dscp <= 0x2f)
+               tc = SO_TC_VI;
+       else if (dscp >= 0x08 && dscp <= 0x17)
+               tc = SO_TC_BK;
+       else
+               tc = SO_TC_BE;
+
+       return (tc);
+}
+
+errno_t
+sock_settclassopt(
+       socket_t        sock,
+       const void      *optval,
+       size_t          optlen) {
+
+       errno_t error = 0;
+       struct sockopt sopt;
+       int sotc;
+
+       if (sock == NULL || optval == NULL || optlen != sizeof(int)) return EINVAL;
+
+       socket_lock(sock, 1);
+       if (!(sock->so_state & SS_ISCONNECTED)) {
+               /* If the socket is not connected then we don't know 
+                * if the destination is on LAN  or not. Skip
+                * setting traffic class in this case
+                */
+               error = ENOTCONN;
+               goto out;
+       }
+
+       if (sock->so_proto == NULL || sock->so_proto->pr_domain == NULL || sock->so_pcb == NULL) {
+               error = EINVAL;
+               goto out;
+       }
+
+       /*
+        * Set the socket traffic class based on the passed DSCP code point
+        * regardless of the scope of the destination
+        */
+       sotc = so_tc_from_dscp((*(const int *)optval) >> 2);
+
+       sopt.sopt_dir = SOPT_SET;
+       sopt.sopt_val = CAST_USER_ADDR_T(&sotc);
+       sopt.sopt_valsize = sizeof(sotc);
+       sopt.sopt_p = kernproc;
+       sopt.sopt_level = SOL_SOCKET;
+       sopt.sopt_name = SO_TRAFFIC_CLASS;
+
+       socket_unlock(sock, 0);
+       error = sosetopt(sock, &sopt);
+       socket_lock(sock, 0);
+
+       if (error != 0) {
+               printf("sock_settclassopt: sosetopt SO_TRAFFIC_CLASS failed %d\n", error);
+               goto out;
+       }
+
+       /* Check if the destination address is LAN or link local address.
+        * We do not want to set traffic class bits if the destination
+        * is not local 
+        */ 
+       if (!so_isdstlocal(sock)) {
+               goto out;
+       }
+
+       sopt.sopt_dir = SOPT_SET;
+       sopt.sopt_val = CAST_USER_ADDR_T(optval);
+       sopt.sopt_valsize = optlen;
+       sopt.sopt_p = kernproc;
+
+       switch (sock->so_proto->pr_domain->dom_family) {
+       case AF_INET:
+               sopt.sopt_level = IPPROTO_IP;
+               sopt.sopt_name = IP_TOS;
+               break;
+       case AF_INET6:
+               sopt.sopt_level = IPPROTO_IPV6;
+               sopt.sopt_name = IPV6_TCLASS;
+               break;
+       default:
+               error = EINVAL;
+               goto out;
+       }
+       
+       socket_unlock(sock, 1);
+       return sosetopt(sock, &sopt);
+out:
+       socket_unlock(sock, 1);
+       return error;
+}
+
+errno_t
+sock_gettclassopt(
+       socket_t        sock,
+       void            *optval,
+       size_t          *optlen) {
+
+       errno_t         error = 0;
+       struct sockopt  sopt;
+       
+       if (sock == NULL || optval == NULL || optlen == NULL) return EINVAL;
+
+       sopt.sopt_dir = SOPT_GET;
+       sopt.sopt_val = CAST_USER_ADDR_T(optval); 
+       sopt.sopt_valsize = *optlen;
+       sopt.sopt_p = kernproc;
+
+       socket_lock(sock, 1);
+       if (sock->so_proto == NULL || sock->so_proto->pr_domain == NULL) {
+               socket_unlock(sock, 1);
+               return EINVAL;
+       }
+
+       switch (sock->so_proto->pr_domain->dom_family) {
+       case AF_INET:
+               sopt.sopt_level = IPPROTO_IP;
+               sopt.sopt_name = IP_TOS;
+               break;
+       case AF_INET6:
+               sopt.sopt_level = IPPROTO_IPV6;
+               sopt.sopt_name = IPV6_TCLASS;
+               break;
+       default:
+               socket_unlock(sock, 1);
+               return EINVAL;
+
+       }
+       socket_unlock(sock, 1);
+       error = sogetopt(sock, &sopt); /* will lock socket */
+       if (error == 0) *optlen = sopt.sopt_valsize;
+       return error;
+}
+
 errno_t
 sock_listen(
        socket_t        sock,
 errno_t
 sock_listen(
        socket_t        sock,
@@ -423,7 +688,7 @@ sock_receive_internal(
                                                                  &uio_buf[0], sizeof(uio_buf));
        if (msg && data == NULL) {
                int i;
                                                                  &uio_buf[0], sizeof(uio_buf));
        if (msg && data == NULL) {
                int i;
-               struct iovec_32 *tempp = (struct iovec_32 *) msg->msg_iov;
+               struct iovec *tempp = msg->msg_iov;
                
                for (i = 0; i < msg->msg_iovlen; i++) {
                        uio_addiov(auio, CAST_USER_ADDR_T((tempp + i)->iov_base), (tempp + i)->iov_len);
                
                for (i = 0; i < msg->msg_iovlen; i++) {
                        uio_addiov(auio, CAST_USER_ADDR_T((tempp + i)->iov_base), (tempp + i)->iov_len);
@@ -437,19 +702,10 @@ sock_receive_internal(
        
        if (recvdlen)
                *recvdlen = 0;
        
        if (recvdlen)
                *recvdlen = 0;
-       
-       if (msg && msg->msg_control) {
-               if ((size_t)msg->msg_controllen < sizeof(struct cmsghdr)) return EINVAL;
-               if ((size_t)msg->msg_controllen > MLEN) return EINVAL;
-               control = m_get(M_NOWAIT, MT_CONTROL);
-               if (control == NULL) return ENOMEM;
-               memcpy(mtod(control, caddr_t), msg->msg_control, msg->msg_controllen);
-               control->m_len = msg->msg_controllen;
-       }
 
        /* let pru_soreceive handle the socket locking */       
        error = sock->so_proto->pr_usrreqs->pru_soreceive(sock, &fromsa, auio,
 
        /* let pru_soreceive handle the socket locking */       
        error = sock->so_proto->pr_usrreqs->pru_soreceive(sock, &fromsa, auio,
-                               data, control ? &control : NULL, &flags);
+           data, (msg && msg->msg_control) ? &control : NULL, &flags);
        if (error) goto cleanup;
        
        if (recvdlen)
        if (error) goto cleanup;
        
        if (recvdlen)
@@ -493,7 +749,7 @@ sock_receive_internal(
                                clen -= tocopy;
                                m = m->m_next;
                        }
                                clen -= tocopy;
                                m = m->m_next;
                        }
-                       msg->msg_controllen = (u_int32_t)ctlbuf - (u_int32_t)msg->msg_control;
+                       msg->msg_controllen = (uintptr_t)ctlbuf - (uintptr_t)msg->msg_control;
                }
        }
 
                }
        }
 
@@ -552,7 +808,7 @@ sock_send_internal(
        }
        
        if (data == 0 && msg != NULL) {
        }
        
        if (data == 0 && msg != NULL) {
-               struct iovec_32 *tempp = (struct iovec_32 *) msg->msg_iov;
+               struct iovec *tempp = msg->msg_iov;
 
                auio = uio_createwithbuffer(msg->msg_iovlen, 0, UIO_SYSSPACE, UIO_WRITE, 
                                                                  &uio_buf[0], sizeof(uio_buf));
 
                auio = uio_createwithbuffer(msg->msg_iovlen, 0, UIO_SYSSPACE, UIO_WRITE, 
                                                                  &uio_buf[0], sizeof(uio_buf));
@@ -668,7 +924,6 @@ sock_shutdown(
        return soshutdown(sock, how);
 }
 
        return soshutdown(sock, how);
 }
 
-typedef        void    (*so_upcall)(struct socket *sock, void* arg, int waitf);
 
 errno_t
 sock_socket(
 
 errno_t
 sock_socket(
@@ -686,8 +941,13 @@ sock_socket(
        if (error == 0 && callback)
        {
                (*new_so)->so_rcv.sb_flags |= SB_UPCALL;
        if (error == 0 && callback)
        {
                (*new_so)->so_rcv.sb_flags |= SB_UPCALL;
+#if CONFIG_SOWUPCALL
+               (*new_so)->so_snd.sb_flags |= SB_UPCALL;
+#endif
                (*new_so)->so_upcall = (so_upcall)callback;
                (*new_so)->so_upcallarg = context;
                (*new_so)->so_upcall = (so_upcall)callback;
                (*new_so)->so_upcallarg = context;
+               (*new_so)->last_pid = 0;
+               (*new_so)->last_upid = 0;
        }
        return error;
 }
        }
        return error;
 }
@@ -714,19 +974,26 @@ sock_retain(
 
 /* Do we want this to be APPLE_PRIVATE API? */
 void
 
 /* Do we want this to be APPLE_PRIVATE API? */
 void
-sock_release(
-       socket_t        sock)
+sock_release(socket_t sock)
 {
 {
-       if (sock == NULL) return;
+       if (sock == NULL)
+               return;
        socket_lock(sock, 1);
        socket_lock(sock, 1);
+
+       if (sock->so_upcallusecount)
+               soclose_wait_locked(sock);
+
        sock->so_retaincnt--;
        if (sock->so_retaincnt < 0)
        sock->so_retaincnt--;
        if (sock->so_retaincnt < 0)
-               panic("sock_release: negative retain count for sock=%x cnt=%x\n",
-                       sock, sock->so_retaincnt);
-       if ((sock->so_retaincnt == 0) && (sock->so_usecount == 2))
-               soclose_locked(sock); /* close socket only if the FD is not holding it */
-       else
-               sock->so_usecount--;    /* remove extra reference holding the socket */
+               panic("sock_release: negative retain count for sock=%p "
+                   "cnt=%x\n", sock, sock->so_retaincnt);
+       if ((sock->so_retaincnt == 0) && (sock->so_usecount == 2)) {
+               /* close socket only if the FD is not holding it */
+               soclose_locked(sock);
+       } else {
+               /* remove extra reference holding the socket */
+               sock->so_usecount--;
+       }
        socket_unlock(sock, 1);
 }
 
        socket_unlock(sock, 1);
 }
 
@@ -788,3 +1055,129 @@ sock_gettype(
        socket_unlock(sock, 1);
        return 0;
 }
        socket_unlock(sock, 1);
        return 0;
 }
+
+/*
+ * Return the listening socket of a pre-accepted socket.  It returns the
+ * listener (so_head) value of a given socket.  This is intended to be
+ * called by a socket filter during a filter attach (sf_attach) callback.
+ * The value returned by this routine is safe to be used only in the
+ * context of that callback, because we hold the listener's lock across
+ * the sflt_initsock() call.
+ */
+socket_t
+sock_getlistener(socket_t sock)
+{
+       return (sock->so_head);
+}
+
+static inline void
+sock_set_tcp_stream_priority(socket_t sock)
+{
+       if ((sock->so_proto->pr_domain->dom_family == AF_INET || 
+               sock->so_proto->pr_domain->dom_family == AF_INET6) &&
+               sock->so_proto->pr_type == SOCK_STREAM) {
+
+               set_tcp_stream_priority(sock);
+
+       }
+}
+
+/*
+ * Caller must have ensured socket is valid and won't be going away.
+ */
+void
+socket_set_traffic_mgt_flags_locked(socket_t sock, u_int32_t flags)
+{
+       (void) OSBitOrAtomic(flags, &sock->so_traffic_mgt_flags);
+       sock_set_tcp_stream_priority(sock);
+}
+
+void
+socket_set_traffic_mgt_flags(socket_t sock, u_int32_t flags)
+{
+       socket_lock(sock, 1);
+       socket_set_traffic_mgt_flags_locked(sock, flags);
+       socket_unlock(sock, 1);
+}
+
+/*
+ * Caller must have ensured socket is valid and won't be going away.
+ */
+void
+socket_clear_traffic_mgt_flags_locked(socket_t sock, u_int32_t flags)
+{
+       (void) OSBitAndAtomic(~flags, &sock->so_traffic_mgt_flags);
+       sock_set_tcp_stream_priority(sock);
+}
+
+void
+socket_clear_traffic_mgt_flags(socket_t sock, u_int32_t flags)
+{
+       socket_lock(sock, 1);
+       socket_clear_traffic_mgt_flags_locked(sock, flags);
+       socket_unlock(sock, 1);
+}
+
+
+/*
+ * Caller must have ensured socket is valid and won't be going away.
+ */
+errno_t
+socket_defunct(struct proc *p, socket_t so, int level)
+{
+       errno_t retval;
+
+       if (level != SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC &&
+           level != SHUTDOWN_SOCKET_LEVEL_DISCONNECT_ALL)
+               return (EINVAL);
+
+       socket_lock(so, 1);
+       /*
+        * SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC level is meant to tear down
+        * all of mDNSResponder IPC sockets, currently those of AF_UNIX; note
+        * that this is an implementation artifact of mDNSResponder.  We do
+        * a quick test against the socket buffers for SB_UNIX, since that
+        * would have been set by unp_attach() at socket creation time.
+        */
+       if (level == SHUTDOWN_SOCKET_LEVEL_DISCONNECT_SVC &&
+           (so->so_rcv.sb_flags & so->so_snd.sb_flags & SB_UNIX) != SB_UNIX) {
+               socket_unlock(so, 1);
+               return (EOPNOTSUPP);
+       }
+       retval = sosetdefunct(p, so, level, TRUE);
+       if (retval == 0)
+               retval = sodefunct(p, so, level);
+       socket_unlock(so, 1);
+       return (retval);
+}
+
+errno_t
+sock_setupcall(socket_t sock, sock_upcall callback, void* context)
+{
+       if (sock == NULL)
+               return EINVAL;
+
+       /*
+        * Note that we don't wait for any in progress upcall to complete.
+        */
+       socket_lock(sock, 1);
+
+       sock->so_upcall = (so_upcall) callback;
+       sock->so_upcallarg = context;
+       if (callback) {
+               sock->so_rcv.sb_flags |= SB_UPCALL;
+#if CONFIG_SOWUPCALL
+               sock->so_snd.sb_flags |= SB_UPCALL;
+#endif /* CONFIG_SOWUPCALL */
+       } else {
+               sock->so_rcv.sb_flags &= ~SB_UPCALL;
+#if CONFIG_SOWUPCALL
+               sock->so_snd.sb_flags &= ~SB_UPCALL;
+#endif /* CONFIG_SOWUPCALL */
+       }
+       
+       socket_unlock(sock, 1);
+
+       return 0;
+}
+