]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/kern/kpi_socket.c
xnu-4570.41.2.tar.gz
[apple/xnu.git] / bsd / kern / kpi_socket.c
index 2251c3f6db611a07551461a79aada2a5b152dcf5..a7b17264d5ec4ee8b28f9cb086549de36074b3d6 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2003-2016 Apple Inc. All rights reserved.
+ * Copyright (c) 2003-2017 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  *
 #include <sys/protosw.h>
 #include <sys/domain.h>
 #include <sys/mbuf.h>
+#include <sys/mcache.h>
 #include <sys/fcntl.h>
 #include <sys/filio.h>
 #include <sys/uio_internal.h>
 #include <kern/locks.h>
+#include <net/net_api_stats.h>
 #include <netinet/in.h>
 #include <libkern/OSAtomic.h>
+#include <stdbool.h>
 
 static errno_t sock_send_internal(socket_t, const struct msghdr        *,
     mbuf_t, int, size_t        *);
-static void sock_setupcalls_common(socket_t, sock_upcall, void *,
-    sock_upcall, void *);
+
+#undef sock_accept
+#undef sock_socket
+errno_t sock_accept(socket_t so, struct sockaddr *from, int fromlen,
+    int flags, sock_upcall callback, void *cookie, socket_t *new_so);
+errno_t sock_socket(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so);
+
+static errno_t sock_accept_common(socket_t sock, struct sockaddr *from,
+    int fromlen, int flags, sock_upcall callback, void *cookie,
+    socket_t *new_sock, bool is_internal);
+static errno_t sock_socket_common(int domain, int type, int protocol,
+    sock_upcall callback, void *context, socket_t *new_so, bool is_internal);
 
 errno_t
-sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
-    sock_upcall callback, void *cookie, socket_t *new_sock)
+sock_accept_common(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock, bool is_internal)
 {
        struct sockaddr *sa;
        struct socket *new_so;
@@ -73,6 +87,7 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
                socket_unlock(sock, 1);
                return (ENOTSUP);
        }
+check_again:
        if (((flags & MSG_DONTWAIT) != 0 || (sock->so_state & SS_NBIO) != 0) &&
            sock->so_comp.tqh_first == NULL) {
                socket_unlock(sock, 1);
@@ -80,7 +95,7 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
        }
 
        if (sock->so_proto->pr_getlock != NULL)  {
-               mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+               mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
                dosocklock = 1;
        } else {
                mutex_held = sock->so_proto->pr_domain->dom_mtx;
@@ -106,10 +121,28 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
                return (error);
        }
 
+       so_acquire_accept_list(sock, NULL);
+       if (TAILQ_EMPTY(&sock->so_comp)) {
+               so_release_accept_list(sock);
+               goto check_again;
+       }
        new_so = TAILQ_FIRST(&sock->so_comp);
        TAILQ_REMOVE(&sock->so_comp, new_so, so_list);
+       new_so->so_state &= ~SS_COMP;
+       new_so->so_head = NULL;
        sock->so_qlen--;
 
+       so_release_accept_list(sock);
+
+       /*
+        * Count the accepted socket as an in-kernel socket
+        */
+       new_so->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
+       INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
+       if (is_internal) {
+               INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
+       }
+
        /*
         * Pass the pre-accepted socket to any interested socket filter(s).
         * Upon failure, the socket would have been closed by the callee.
@@ -122,7 +155,7 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
                 * again once we're done with the filter(s).
                 */
                socket_unlock(sock, 0);
-               if ((error = soacceptfilter(new_so)) != 0) {
+               if ((error = soacceptfilter(new_so, sock)) != 0) {
                        /* Drop reference on listening socket */
                        sodereference(sock);
                        return (error);
@@ -131,20 +164,22 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
        }
 
        if (dosocklock) {
-               lck_mtx_assert(new_so->so_proto->pr_getlock(new_so, 0),
+               LCK_MTX_ASSERT(new_so->so_proto->pr_getlock(new_so, 0),
                    LCK_MTX_ASSERT_NOTOWNED);
                socket_lock(new_so, 1);
        }
 
-       new_so->so_state &= ~SS_COMP;
-       new_so->so_head = NULL;
        (void) soacceptlock(new_so, &sa, 0);
 
        socket_unlock(sock, 1); /* release the head */
 
        /* see comments in sock_setupcall() */
        if (callback != NULL) {
-               sock_setupcalls_common(new_so, callback, cookie, NULL, NULL);
+#if CONFIG_EMBEDDED
+               sock_setupcalls_locked(new_so, callback, cookie, callback, cookie, 0);
+#else
+               sock_setupcalls_locked(new_so, callback, cookie, NULL, NULL, 0);
+#endif /* !CONFIG_EMBEDDED */
        }
 
        if (sa != NULL && from != NULL) {
@@ -169,6 +204,22 @@ sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
        return (error);
 }
 
+errno_t
+sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock)
+{
+       return (sock_accept_common(sock, from, fromlen, flags,
+           callback, cookie, new_sock, false));
+}
+
+errno_t
+sock_accept_internal(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock)
+{
+       return (sock_accept_common(sock, from, fromlen, flags,
+           callback, cookie, new_sock, true));
+}
+
 errno_t
 sock_bind(socket_t sock, const struct sockaddr *to)
 {
@@ -238,7 +289,7 @@ sock_connect(socket_t sock, const struct sockaddr *to, int flags)
                }
 
                if (sock->so_proto->pr_getlock != NULL)
-                       mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+                       mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
                else
                        mutex_held = sock->so_proto->pr_domain->dom_mtx;
 
@@ -304,7 +355,7 @@ sock_connectwait(socket_t sock, const struct timeval *tv)
        }
 
        if (sock->so_proto->pr_getlock != NULL)
-               mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+               mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
        else
                mutex_held = sock->so_proto->pr_domain->dom_mtx;
 
@@ -883,10 +934,9 @@ sock_shutdown(socket_t sock, int how)
        return (soshutdown(sock, how));
 }
 
-
 errno_t
-sock_socket(int        domain, int type, int protocol, sock_upcall callback,
-    void *context, socket_t *new_so)
+sock_socket_common(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so, bool is_internal)
 {
        int error = 0;
 
@@ -896,10 +946,18 @@ sock_socket(int   domain, int type, int protocol, sock_upcall callback,
        /* socreate will create an initial so_count */
        error = socreate(domain, new_so, type, protocol);
        if (error == 0) {
+               /*
+                * This is an in-kernel socket
+                */
+               (*new_so)->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
+               INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
+               if (is_internal) {
+                       INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
+               }
+
                /* see comments in sock_setupcall() */
                if (callback != NULL) {
-                       sock_setupcalls_common(*new_so, callback, context,
-                           NULL, NULL);
+                       sock_setupcall(*new_so, callback, context);
                }
                /* 
                 * last_pid and last_upid should be zero for sockets
@@ -911,6 +969,22 @@ sock_socket(int    domain, int type, int protocol, sock_upcall callback,
        return (error);
 }
 
+errno_t
+sock_socket_internal(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so)
+{
+       return (sock_socket_common(domain, type, protocol, callback,
+           context, new_so, true));
+}
+
+errno_t
+sock_socket(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so)
+{
+       return (sock_socket_common(domain, type, protocol, callback,
+           context, new_so, false));
+}
+
 void
 sock_close(socket_t sock)
 {
@@ -961,6 +1035,7 @@ sock_release(socket_t sock)
                soclose_locked(sock);
        } else {
                /* remove extra reference holding the socket */
+               VERIFY(sock->so_usecount > 1);
                sock->so_usecount--;
        }
        socket_unlock(sock, 1);
@@ -1125,26 +1200,30 @@ socket_defunct(struct proc *p, socket_t so, int level)
        return (retval);
 }
 
-static void
-sock_setupcalls_common(socket_t sock, sock_upcall rcallback, void *rcontext,
-    sock_upcall wcallback, void *wcontext)
+void
+sock_setupcalls_locked(socket_t sock, sock_upcall rcallback, void *rcontext,
+    sock_upcall wcallback, void *wcontext, int locked)
 {
        if (rcallback != NULL) {
                sock->so_rcv.sb_flags |= SB_UPCALL;
+               if (locked)
+                       sock->so_rcv.sb_flags |= SB_UPCALL_LOCK;
                sock->so_rcv.sb_upcall = rcallback;
                sock->so_rcv.sb_upcallarg = rcontext;
        } else {
-               sock->so_rcv.sb_flags &= ~SB_UPCALL;
+               sock->so_rcv.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
                sock->so_rcv.sb_upcall = NULL;
                sock->so_rcv.sb_upcallarg = NULL;
        }
 
        if (wcallback != NULL) {
                sock->so_snd.sb_flags |= SB_UPCALL;
+               if (locked)
+                       sock->so_snd.sb_flags |= SB_UPCALL_LOCK;
                sock->so_snd.sb_upcall = wcallback;
                sock->so_snd.sb_upcallarg = wcontext;
        } else {
-               sock->so_snd.sb_flags &= ~SB_UPCALL;
+               sock->so_snd.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
                sock->so_snd.sb_upcall = NULL;
                sock->so_snd.sb_upcallarg = NULL;
        }
@@ -1166,7 +1245,11 @@ sock_setupcall(socket_t sock, sock_upcall callback, void *context)
         * the read and write callbacks and their respective parameters.
         */
        socket_lock(sock, 1);
-       sock_setupcalls_common(sock, callback, context, NULL, NULL);
+#if CONFIG_EMBEDDED
+       sock_setupcalls_locked(sock, callback, context, callback, context, 0);
+#else
+       sock_setupcalls_locked(sock, callback, context, NULL, NULL, 0);
+#endif /* !CONFIG_EMBEDDED */
        socket_unlock(sock, 1);
 
        return (0);
@@ -1183,23 +1266,21 @@ sock_setupcalls(socket_t sock, sock_upcall rcallback, void *rcontext,
         * Note that we don't wait for any in progress upcall to complete.
         */
        socket_lock(sock, 1);
-       sock_setupcalls_common(sock, rcallback, rcontext, wcallback, wcontext);
+       sock_setupcalls_locked(sock, rcallback, rcontext, wcallback, wcontext, 0);
        socket_unlock(sock, 1);
 
        return (0);
 }
 
-errno_t
-sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
+void
+sock_catchevents_locked(socket_t sock, sock_evupcall ecallback, void *econtext,
     u_int32_t emask)
 {
-       if (sock == NULL)
-               return (EINVAL);
+       socket_lock_assert_owned(sock);
 
        /*
         * Note that we don't wait for any in progress upcall to complete.
         */
-       socket_lock(sock, 1);
        if (ecallback != NULL) {
                sock->so_event = ecallback;
                sock->so_eventarg = econtext;
@@ -1209,6 +1290,17 @@ sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
                sock->so_eventarg = NULL;
                sock->so_eventmask = 0;
        }
+}
+
+errno_t
+sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
+    u_int32_t emask)
+{
+       if (sock == NULL)
+               return (EINVAL);
+
+       socket_lock(sock, 1);
+       sock_catchevents_locked(sock, ecallback, econtext, emask);
        socket_unlock(sock, 1);
 
        return (0);