]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/mp_pcb.c
xnu-4570.1.46.tar.gz
[apple/xnu.git] / bsd / netinet / mp_pcb.c
index 31ea83fd0b6a9022a2d9c451e33bdbe3945a64e1..288f29c717b91a1d143cc4a6ee7ea64cd4b51223 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2012-2016 Apple Inc. All rights reserved.
+ * Copyright (c) 2012-2017 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  *
@@ -43,6 +43,7 @@
 
 #include <netinet/mp_pcb.h>
 #include <netinet/mptcp_var.h>
+#include <netinet6/in6_pcb.h>
 
 static lck_grp_t       *mp_lock_grp;
 static lck_attr_t      *mp_lock_attr;
@@ -131,7 +132,7 @@ mp_timeout(void *arg)
 static void
 mp_sched_timeout(void)
 {
-       lck_mtx_assert(&mp_timeout_lock, LCK_MTX_ASSERT_OWNED);
+       LCK_MTX_ASSERT(&mp_timeout_lock, LCK_MTX_ASSERT_OWNED);
 
        if (!mp_timeout_run && (mp_garbage_collecting || mp_ticking)) {
                lck_mtx_convert_spin(&mp_timeout_lock);
@@ -199,27 +200,9 @@ int
 mp_pcballoc(struct socket *so, struct mppcbinfo *mppi)
 {
        struct mppcb *mpp = NULL;
+       int error;
 
-       VERIFY(sotomppcb(so) == NULL);
-
-       lck_mtx_lock(&mppi->mppi_lock);
-       if (mppi->mppi_count >= mptcp_socket_limit) {
-               lck_mtx_unlock(&mppi->mppi_lock);
-               mptcplog((LOG_ERR, "MPTCP Socket: Reached MPTCP socket limit."),
-                   MPTCP_SOCKET_DBG, MPTCP_LOGLVL_ERR);
-               /*
-                * This limit may be reached either because of
-                * a leak or a transient condition where
-                * MPTCP connections are not released fast
-                * enough.
-                * We return EAFNOSUPPORT here to have user
-                * space library fallback to TCP.
-                * XXX We need to revist this when we get rid
-                * of the current low limit imposed on MPTCP.
-                */
-               return (EAFNOSUPPORT);
-       }
-       lck_mtx_unlock(&mppi->mppi_lock);
+       VERIFY(mpsotomppcb(so) == NULL);
 
        mpp = zalloc(mppi->mppi_zone);
        if (mpp == NULL) {
@@ -233,10 +216,11 @@ mp_pcballoc(struct socket *so, struct mppcbinfo *mppi)
        mpp->mpp_socket = so;
        so->so_pcb = mpp;
 
-       if (NULL == mppi->mppi_pcbe_create(so, mpp)) {
+       error = mptcp_sescreate(mpp);
+       if (error) {
                lck_mtx_destroy(&mpp->mpp_lock, mppi->mppi_lock_grp);
                zfree(mppi->mppi_zone, mpp);
-               return (ENOBUFS);
+               return (error);
        }
 
        lck_mtx_lock(&mppi->mppi_lock);
@@ -249,15 +233,13 @@ mp_pcballoc(struct socket *so, struct mppcbinfo *mppi)
 }
 
 void
-mp_pcbdetach(struct mppcb *mpp)
+mp_pcbdetach(struct socket *mp_so)
 {
-       struct socket *so = mpp->mpp_socket;
-
-       VERIFY(so->so_pcb == mpp);
+       struct mppcb *mpp = mpsotomppcb(mp_so);
 
        mpp->mpp_state = MPPCB_STATE_DEAD;
-       if (!(so->so_flags & SOF_PCBCLEARING))
-               so->so_flags |= SOF_PCBCLEARING;
+       if (!(mp_so->so_flags & SOF_PCBCLEARING))
+               mp_so->so_flags |= SOF_PCBCLEARING;
 
        mp_gc_sched();
 }
@@ -269,23 +251,105 @@ mp_pcbdispose(struct mppcb *mpp)
 
        VERIFY(mppi != NULL);
 
-       lck_mtx_assert(&mppi->mppi_lock, LCK_MTX_ASSERT_OWNED);
-       lck_mtx_assert(&mpp->mpp_lock, LCK_MTX_ASSERT_OWNED);
+       LCK_MTX_ASSERT(&mppi->mppi_lock, LCK_MTX_ASSERT_OWNED);
+       mpp_lock_assert_held(mpp);
 
        VERIFY(mpp->mpp_state == MPPCB_STATE_DEAD);
-
        VERIFY(mpp->mpp_flags & MPP_ATTACHED);
+
        mpp->mpp_flags &= ~MPP_ATTACHED;
        TAILQ_REMOVE(&mppi->mppi_pcbs, mpp, mpp_entry);
        VERIFY(mppi->mppi_count != 0);
        mppi->mppi_count--;
 
+       mpp_unlock(mpp);
+
+#if NECP
+       necp_mppcb_dispose(mpp);
+#endif /* NECP */
+
+       lck_mtx_destroy(&mpp->mpp_lock, mppi->mppi_lock_grp);
+
        VERIFY(mpp->mpp_socket != NULL);
        VERIFY(mpp->mpp_socket->so_usecount == 0);
        mpp->mpp_socket->so_pcb = NULL;
        mpp->mpp_socket = NULL;
 
-       lck_mtx_unlock(&mpp->mpp_lock);
-       lck_mtx_destroy(&mpp->mpp_lock, mppi->mppi_lock_grp);
        zfree(mppi->mppi_zone, mpp);
 }
+
+static int
+mp_getaddr_v4(struct socket *mp_so, struct sockaddr **nam, boolean_t peer)
+{
+       struct mptses *mpte = mpsotompte(mp_so);
+       struct sockaddr_in *sin;
+
+       /*
+        * Do the malloc first in case it blocks.
+        */
+       MALLOC(sin, struct sockaddr_in *, sizeof (*sin), M_SONAME, M_WAITOK);
+       if (sin == NULL)
+               return (ENOBUFS);
+       bzero(sin, sizeof (*sin));
+       sin->sin_family = AF_INET;
+       sin->sin_len = sizeof (*sin);
+
+       if (!peer) {
+               sin->sin_port = mpte->__mpte_src_v4.sin_port;
+               sin->sin_addr = mpte->__mpte_src_v4.sin_addr;
+       } else {
+               sin->sin_port = mpte->__mpte_dst_v4.sin_port;
+               sin->sin_addr = mpte->__mpte_dst_v4.sin_addr;
+       }
+
+       *nam = (struct sockaddr *)sin;
+       return (0);
+}
+
+static int
+mp_getaddr_v6(struct socket *mp_so, struct sockaddr **nam, boolean_t peer)
+{
+       struct mptses *mpte = mpsotompte(mp_so);
+       struct in6_addr addr;
+       in_port_t port;
+
+       if (!peer) {
+               port = mpte->__mpte_src_v6.sin6_port;
+               addr = mpte->__mpte_src_v6.sin6_addr;
+       } else {
+               port = mpte->__mpte_dst_v6.sin6_port;
+               addr = mpte->__mpte_dst_v6.sin6_addr;
+       }
+
+       *nam = in6_sockaddr(port, &addr);
+       if (*nam == NULL)
+               return (ENOBUFS);
+
+       return (0);
+}
+
+int
+mp_getsockaddr(struct socket *mp_so, struct sockaddr **nam)
+{
+       struct mptses *mpte = mpsotompte(mp_so);
+
+       if (mpte->mpte_src.sa_family == AF_INET || mpte->mpte_src.sa_family == 0)
+               return mp_getaddr_v4(mp_so, nam, false);
+       else if (mpte->mpte_src.sa_family == AF_INET)
+               return mp_getaddr_v6(mp_so, nam, false);
+       else
+               return (EINVAL);
+}
+
+int
+mp_getpeeraddr(struct socket *mp_so, struct sockaddr **nam)
+{
+       struct mptses *mpte = mpsotompte(mp_so);
+
+       if (mpte->mpte_src.sa_family == AF_INET || mpte->mpte_src.sa_family == 0)
+               return mp_getaddr_v4(mp_so, nam, true);
+       else if (mpte->mpte_src.sa_family == AF_INET)
+               return mp_getaddr_v6(mp_so, nam, true);
+       else
+               return (EINVAL);
+}