]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/kern/kpi_socket.c
xnu-4570.41.2.tar.gz
[apple/xnu.git] / bsd / kern / kpi_socket.c
index 2f1b1d96a8af12006babef4a89d17f69d654b35c..a7b17264d5ec4ee8b28f9cb086549de36074b3d6 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2003-2016 Apple Inc. All rights reserved.
+ * Copyright (c) 2003-2017 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  *
 #include <sys/filio.h>
 #include <sys/uio_internal.h>
 #include <kern/locks.h>
+#include <net/net_api_stats.h>
 #include <netinet/in.h>
 #include <libkern/OSAtomic.h>
+#include <stdbool.h>
 
 static errno_t sock_send_internal(socket_t, const struct msghdr        *,
     mbuf_t, int, size_t        *);
-static void sock_setupcalls_common(socket_t, sock_upcall, void *,
-    sock_upcall, void *);
+
+#undef sock_accept
+#undef sock_socket
+errno_t sock_accept(socket_t so, struct sockaddr *from, int fromlen,
+    int flags, sock_upcall callback, void *cookie, socket_t *new_so);
+errno_t sock_socket(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so);
+
+static errno_t sock_accept_common(socket_t sock, struct sockaddr *from,
+    int fromlen, int flags, sock_upcall callback, void *cookie,
+    socket_t *new_sock, bool is_internal);
+static errno_t sock_socket_common(int domain, int type, int protocol,
+    sock_upcall callback, void *context, socket_t *new_so, bool is_internal);
 
 errno_t
-sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
-    sock_upcall callback, void *cookie, socket_t *new_sock)
+sock_accept_common(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock, bool is_internal)
 {
        struct sockaddr *sa;
        struct socket *new_so;
@@ -82,7 +95,7 @@ check_again:
        }
 
        if (sock->so_proto->pr_getlock != NULL)  {
-               mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+               mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
                dosocklock = 1;
        } else {
                mutex_held = sock->so_proto->pr_domain->dom_mtx;
@@ -121,6 +134,15 @@ check_again:
 
        so_release_accept_list(sock);
 
+       /*
+        * Count the accepted socket as an in-kernel socket
+        */
+       new_so->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
+       INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
+       if (is_internal) {
+               INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
+       }
+
        /*
         * Pass the pre-accepted socket to any interested socket filter(s).
         * Upon failure, the socket would have been closed by the callee.
@@ -142,7 +164,7 @@ check_again:
        }
 
        if (dosocklock) {
-               lck_mtx_assert(new_so->so_proto->pr_getlock(new_so, 0),
+               LCK_MTX_ASSERT(new_so->so_proto->pr_getlock(new_so, 0),
                    LCK_MTX_ASSERT_NOTOWNED);
                socket_lock(new_so, 1);
        }
@@ -153,7 +175,11 @@ check_again:
 
        /* see comments in sock_setupcall() */
        if (callback != NULL) {
-               sock_setupcalls_common(new_so, callback, cookie, NULL, NULL);
+#if CONFIG_EMBEDDED
+               sock_setupcalls_locked(new_so, callback, cookie, callback, cookie, 0);
+#else
+               sock_setupcalls_locked(new_so, callback, cookie, NULL, NULL, 0);
+#endif /* !CONFIG_EMBEDDED */
        }
 
        if (sa != NULL && from != NULL) {
@@ -178,6 +204,22 @@ check_again:
        return (error);
 }
 
+errno_t
+sock_accept(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock)
+{
+       return (sock_accept_common(sock, from, fromlen, flags,
+           callback, cookie, new_sock, false));
+}
+
+errno_t
+sock_accept_internal(socket_t sock, struct sockaddr *from, int fromlen, int flags,
+    sock_upcall callback, void *cookie, socket_t *new_sock)
+{
+       return (sock_accept_common(sock, from, fromlen, flags,
+           callback, cookie, new_sock, true));
+}
+
 errno_t
 sock_bind(socket_t sock, const struct sockaddr *to)
 {
@@ -247,7 +289,7 @@ sock_connect(socket_t sock, const struct sockaddr *to, int flags)
                }
 
                if (sock->so_proto->pr_getlock != NULL)
-                       mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+                       mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
                else
                        mutex_held = sock->so_proto->pr_domain->dom_mtx;
 
@@ -313,7 +355,7 @@ sock_connectwait(socket_t sock, const struct timeval *tv)
        }
 
        if (sock->so_proto->pr_getlock != NULL)
-               mutex_held = (*sock->so_proto->pr_getlock)(sock, 0);
+               mutex_held = (*sock->so_proto->pr_getlock)(sock, PR_F_WILLUNLOCK);
        else
                mutex_held = sock->so_proto->pr_domain->dom_mtx;
 
@@ -892,10 +934,9 @@ sock_shutdown(socket_t sock, int how)
        return (soshutdown(sock, how));
 }
 
-
 errno_t
-sock_socket(int        domain, int type, int protocol, sock_upcall callback,
-    void *context, socket_t *new_so)
+sock_socket_common(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so, bool is_internal)
 {
        int error = 0;
 
@@ -905,10 +946,18 @@ sock_socket(int   domain, int type, int protocol, sock_upcall callback,
        /* socreate will create an initial so_count */
        error = socreate(domain, new_so, type, protocol);
        if (error == 0) {
+               /*
+                * This is an in-kernel socket
+                */
+               (*new_so)->so_flags1 |= SOF1_IN_KERNEL_SOCKET;
+               INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_total);
+               if (is_internal) {
+                       INC_ATOMIC_INT64_LIM(net_api_stats.nas_socket_in_kernel_os_total);
+               }
+
                /* see comments in sock_setupcall() */
                if (callback != NULL) {
-                       sock_setupcalls_common(*new_so, callback, context,
-                           NULL, NULL);
+                       sock_setupcall(*new_so, callback, context);
                }
                /* 
                 * last_pid and last_upid should be zero for sockets
@@ -920,6 +969,22 @@ sock_socket(int    domain, int type, int protocol, sock_upcall callback,
        return (error);
 }
 
+errno_t
+sock_socket_internal(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so)
+{
+       return (sock_socket_common(domain, type, protocol, callback,
+           context, new_so, true));
+}
+
+errno_t
+sock_socket(int domain, int type, int protocol, sock_upcall callback,
+    void *context, socket_t *new_so)
+{
+       return (sock_socket_common(domain, type, protocol, callback,
+           context, new_so, false));
+}
+
 void
 sock_close(socket_t sock)
 {
@@ -1135,26 +1200,30 @@ socket_defunct(struct proc *p, socket_t so, int level)
        return (retval);
 }
 
-static void
-sock_setupcalls_common(socket_t sock, sock_upcall rcallback, void *rcontext,
-    sock_upcall wcallback, void *wcontext)
+void
+sock_setupcalls_locked(socket_t sock, sock_upcall rcallback, void *rcontext,
+    sock_upcall wcallback, void *wcontext, int locked)
 {
        if (rcallback != NULL) {
                sock->so_rcv.sb_flags |= SB_UPCALL;
+               if (locked)
+                       sock->so_rcv.sb_flags |= SB_UPCALL_LOCK;
                sock->so_rcv.sb_upcall = rcallback;
                sock->so_rcv.sb_upcallarg = rcontext;
        } else {
-               sock->so_rcv.sb_flags &= ~SB_UPCALL;
+               sock->so_rcv.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
                sock->so_rcv.sb_upcall = NULL;
                sock->so_rcv.sb_upcallarg = NULL;
        }
 
        if (wcallback != NULL) {
                sock->so_snd.sb_flags |= SB_UPCALL;
+               if (locked)
+                       sock->so_snd.sb_flags |= SB_UPCALL_LOCK;
                sock->so_snd.sb_upcall = wcallback;
                sock->so_snd.sb_upcallarg = wcontext;
        } else {
-               sock->so_snd.sb_flags &= ~SB_UPCALL;
+               sock->so_snd.sb_flags &= ~(SB_UPCALL | SB_UPCALL_LOCK);
                sock->so_snd.sb_upcall = NULL;
                sock->so_snd.sb_upcallarg = NULL;
        }
@@ -1176,7 +1245,11 @@ sock_setupcall(socket_t sock, sock_upcall callback, void *context)
         * the read and write callbacks and their respective parameters.
         */
        socket_lock(sock, 1);
-       sock_setupcalls_common(sock, callback, context, NULL, NULL);
+#if CONFIG_EMBEDDED
+       sock_setupcalls_locked(sock, callback, context, callback, context, 0);
+#else
+       sock_setupcalls_locked(sock, callback, context, NULL, NULL, 0);
+#endif /* !CONFIG_EMBEDDED */
        socket_unlock(sock, 1);
 
        return (0);
@@ -1193,23 +1266,21 @@ sock_setupcalls(socket_t sock, sock_upcall rcallback, void *rcontext,
         * Note that we don't wait for any in progress upcall to complete.
         */
        socket_lock(sock, 1);
-       sock_setupcalls_common(sock, rcallback, rcontext, wcallback, wcontext);
+       sock_setupcalls_locked(sock, rcallback, rcontext, wcallback, wcontext, 0);
        socket_unlock(sock, 1);
 
        return (0);
 }
 
-errno_t
-sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
+void
+sock_catchevents_locked(socket_t sock, sock_evupcall ecallback, void *econtext,
     u_int32_t emask)
 {
-       if (sock == NULL)
-               return (EINVAL);
+       socket_lock_assert_owned(sock);
 
        /*
         * Note that we don't wait for any in progress upcall to complete.
         */
-       socket_lock(sock, 1);
        if (ecallback != NULL) {
                sock->so_event = ecallback;
                sock->so_eventarg = econtext;
@@ -1219,6 +1290,17 @@ sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
                sock->so_eventarg = NULL;
                sock->so_eventmask = 0;
        }
+}
+
+errno_t
+sock_catchevents(socket_t sock, sock_evupcall ecallback, void *econtext,
+    u_int32_t emask)
+{
+       if (sock == NULL)
+               return (EINVAL);
+
+       socket_lock(sock, 1);
+       sock_catchevents_locked(sock, ecallback, econtext, emask);
        socket_unlock(sock, 1);
 
        return (0);