]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/mptcp_usrreq.c
xnu-3789.31.2.tar.gz
[apple/xnu.git] / bsd / netinet / mptcp_usrreq.c
index d61ad1fc3c6e3d964c6d128991410648545df9ae..a3118841a5a0dadfd084c00ed7922488e37aa576 100644 (file)
@@ -88,6 +88,7 @@ static int mptcp_setopt(struct mptses *, struct sockopt *);
 static int mptcp_getopt(struct mptses *, struct sockopt *);
 static int mptcp_default_tcp_optval(struct mptses *, struct sockopt *, int *);
 static void mptcp_connorder_helper(struct mptsub *mpts);
 static int mptcp_getopt(struct mptses *, struct sockopt *);
 static int mptcp_default_tcp_optval(struct mptses *, struct sockopt *, int *);
 static void mptcp_connorder_helper(struct mptsub *mpts);
+static int mptcp_usr_preconnect(struct socket *so);
 
 struct pr_usrreqs mptcp_usrreqs = {
        .pru_attach =           mptcp_usr_attach,
 
 struct pr_usrreqs mptcp_usrreqs = {
        .pru_attach =           mptcp_usr_attach,
@@ -103,6 +104,7 @@ struct pr_usrreqs mptcp_usrreqs = {
        .pru_sosend =           mptcp_usr_sosend,
        .pru_soreceive =        soreceive,
        .pru_socheckopt =       mptcp_usr_socheckopt,
        .pru_sosend =           mptcp_usr_sosend,
        .pru_soreceive =        soreceive,
        .pru_socheckopt =       mptcp_usr_socheckopt,
+       .pru_preconnect =       mptcp_usr_preconnect,
 };
 
 /*
 };
 
 /*
@@ -165,6 +167,10 @@ mptcp_attach(struct socket *mp_so, struct proc *p)
                        goto out;
        }
 
                        goto out;
        }
 
+       if (mp_so->so_snd.sb_preconn_hiwat == 0) {
+               soreserve_preconnect(mp_so, 2048);
+       }
+
        /*
         * MPTCP socket buffers cannot be compressed, due to the
         * fact that each mbuf chained via m_next is a M_PKTHDR
        /*
         * MPTCP socket buffers cannot be compressed, due to the
         * fact that each mbuf chained via m_next is a M_PKTHDR
@@ -306,12 +312,12 @@ static int
 mptcp_usr_connectx(struct socket *mp_so, struct sockaddr_list **src_sl,
     struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
     sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
 mptcp_usr_connectx(struct socket *mp_so, struct sockaddr_list **src_sl,
     struct sockaddr_list **dst_sl, struct proc *p, uint32_t ifscope,
     sae_associd_t aid, sae_connid_t *pcid, uint32_t flags, void *arg,
-    uint32_t arglen, struct uio *uio, user_ssize_t *bytes_written)
+    uint32_t arglen, struct uio *auio, user_ssize_t *bytes_written)
 {
 {
-#pragma unused(arg, arglen, uio, bytes_written)
        struct mppcb *mpp = sotomppcb(mp_so);
        struct mptses *mpte = NULL;
        struct mptcb *mp_tp = NULL;
        struct mppcb *mpp = sotomppcb(mp_so);
        struct mptses *mpte = NULL;
        struct mptcb *mp_tp = NULL;
+       user_ssize_t    datalen;
 
        int error = 0;
 
 
        int error = 0;
 
@@ -332,6 +338,33 @@ mptcp_usr_connectx(struct socket *mp_so, struct sockaddr_list **src_sl,
 
        error = mptcp_connectx(mpte, src_sl, dst_sl, p, ifscope,
            aid, pcid, flags, arg, arglen);
 
        error = mptcp_connectx(mpte, src_sl, dst_sl, p, ifscope,
            aid, pcid, flags, arg, arglen);
+
+       /* If there is data, copy it */
+       if (auio != NULL) {
+               datalen = uio_resid(auio);
+               socket_unlock(mp_so, 0);
+               error = mp_so->so_proto->pr_usrreqs->pru_sosend(mp_so, NULL,
+                   (uio_t) auio, NULL, NULL, 0);
+               /* check if this can be supported with fast Join also. XXX */
+               if (error == 0 || error == EWOULDBLOCK)
+                       *bytes_written = datalen - uio_resid(auio);
+
+               if (error == EWOULDBLOCK)
+                       error = EINPROGRESS;
+
+               socket_lock(mp_so, 0);
+               MPT_LOCK(mp_tp);
+               if (mp_tp->mpt_flags & MPTCPF_PEEL_OFF) {
+                       *bytes_written = datalen - uio_resid(auio);
+                       /*
+                        * Override errors like EPIPE that occur as
+                        * a result of doing TFO during TCP fallback.
+                        */
+                       error = EPROTO;
+               }
+               MPT_UNLOCK(mp_tp);
+       }
+
 out:
        return (error);
 }
 out:
        return (error);
 }
@@ -589,7 +622,7 @@ mptcp_connorder_helper(struct mptsub *mpts)
        struct tcpcb *tp = NULL;
 
        socket_lock(so, 0);
        struct tcpcb *tp = NULL;
 
        socket_lock(so, 0);
-       
+
        tp = intotcpcb(sotoinpcb(so));
        tp->t_mpflags |= TMPF_SND_MPPRIO;
        if (mpts->mpts_flags & MPTSF_PREFERRED)
        tp = intotcpcb(sotoinpcb(so));
        tp->t_mpflags |= TMPF_SND_MPPRIO;
        if (mpts->mpts_flags & MPTSF_PREFERRED)
@@ -811,7 +844,7 @@ mptcp_disconnectx(struct mptses *mpte, sae_associd_t aid, sae_connid_t cid)
        } else {
                bool disconnect_embryonic_subflows = false;
                struct socket *so = NULL;
        } else {
                bool disconnect_embryonic_subflows = false;
                struct socket *so = NULL;
-               
+
                TAILQ_FOREACH(mpts, &mpte->mpte_subflows, mpts_entry) {
                        if (mpts->mpts_connid != cid)
                                continue;
                TAILQ_FOREACH(mpts, &mpte->mpte_subflows, mpts_entry) {
                        if (mpts->mpts_connid != cid)
                                continue;
@@ -1108,7 +1141,8 @@ mptcp_usr_send(struct socket *mp_so, int prus_flags, struct mbuf *m,
        mpte = mptompte(mpp);
        VERIFY(mpte != NULL);
 
        mpte = mptompte(mpp);
        VERIFY(mpte != NULL);
 
-       if (!(mp_so->so_state & SS_ISCONNECTED)) {
+       if (!(mp_so->so_state & SS_ISCONNECTED) &&
+            (!(mp_so->so_flags1 & SOF1_PRECONNECT_DATA))) {
                error = ENOTCONN;
                goto out;
        }
                error = ENOTCONN;
                goto out;
        }
@@ -1118,13 +1152,20 @@ mptcp_usr_send(struct socket *mp_so, int prus_flags, struct mbuf *m,
        (void) sbappendstream(&mp_so->so_snd, m);
        m = NULL;
 
        (void) sbappendstream(&mp_so->so_snd, m);
        m = NULL;
 
-       if (mpte != NULL) {
-               /*
-                * XXX: adi@apple.com
-                *
-                * PRUS_MORETOCOME could be set, but we don't check it now.
-                */
-               error = mptcp_output(mpte);
+       /*
+        * XXX: adi@apple.com
+        *
+        * PRUS_MORETOCOME could be set, but we don't check it now.
+        */
+       error = mptcp_output(mpte);
+       if (error != 0)
+               goto out;
+
+       if (mp_so->so_state & SS_ISCONNECTING) {
+               if (mp_so->so_state & SS_NBIO)
+                       error = EWOULDBLOCK;
+               else
+                       error = sbwait(&mp_so->so_snd);
        }
 
 out:
        }
 
 out:
@@ -1377,6 +1418,10 @@ out:
        if (control != NULL)
                m_freem(control);
 
        if (control != NULL)
                m_freem(control);
 
+       /* clear SOF1_PRECONNECT_DATA after one write */
+       if (mp_so->so_flags1 & SOF1_PRECONNECT_DATA)
+               mp_so->so_flags1 &= ~SOF1_PRECONNECT_DATA;
+
        return (error);
 }
 
        return (error);
 }
 
@@ -1453,6 +1498,7 @@ mptcp_usr_socheckopt(struct socket *mp_so, struct sockopt *sopt)
        case SO_FLUSH:                          /* MP + subflow */
        case SO_MPTCP_FASTJOIN:                 /* MP + subflow */
        case SO_NOWAKEFROMSLEEP:
        case SO_FLUSH:                          /* MP + subflow */
        case SO_MPTCP_FASTJOIN:                 /* MP + subflow */
        case SO_NOWAKEFROMSLEEP:
+       case SO_NOAPNFALLBK:
                /*
                 * Tell the caller that these options are to be processed;
                 * these will also be recorded later by mptcp_setopt().
                /*
                 * Tell the caller that these options are to be processed;
                 * these will also be recorded later by mptcp_setopt().
@@ -1631,6 +1677,7 @@ mptcp_setopt(struct mptses *mpte, struct sockopt *sopt)
                case SO_RESTRICTIONS:
                case SO_NOWAKEFROMSLEEP:
                case SO_MPTCP_FASTJOIN:
                case SO_RESTRICTIONS:
                case SO_NOWAKEFROMSLEEP:
                case SO_MPTCP_FASTJOIN:
+               case SO_NOAPNFALLBK:
                        /* record it */
                        break;
                case SO_FLUSH:
                        /* record it */
                        break;
                case SO_FLUSH:
@@ -2024,6 +2071,9 @@ mptcp_sopt2str(int level, int optname, char *dst, int size)
                case SO_MPTCP_FASTJOIN:
                        o = "SO_MPTCP_FASTJOIN";
                        break;
                case SO_MPTCP_FASTJOIN:
                        o = "SO_MPTCP_FASTJOIN";
                        break;
+               case SO_NOAPNFALLBK:
+                       o = "SO_NOAPNFALLBK";
+                       break;
                }
                break;
        case IPPROTO_TCP:
                }
                break;
        case IPPROTO_TCP:
@@ -2054,3 +2104,37 @@ mptcp_sopt2str(int level, int optname, char *dst, int size)
        (void) snprintf(dst, size, "<%s,%s>", l, o);
        return (dst);
 }
        (void) snprintf(dst, size, "<%s,%s>", l, o);
        return (dst);
 }
+
+static int
+mptcp_usr_preconnect(struct socket *mp_so)
+{
+       struct mptsub *mpts = NULL;
+       struct mppcb *mpp = sotomppcb(mp_so);
+       struct mptses *mpte;
+       struct socket *so;
+       struct tcpcb *tp = NULL;
+
+       mpte = mptompte(mpp);
+       VERIFY(mpte != NULL);
+       MPTE_LOCK_ASSERT_HELD(mpte);    /* same as MP socket lock */
+
+       mpts = mptcp_get_subflow(mpte, NULL, NULL);
+       if (mpts == NULL) {
+               mptcplog((LOG_ERR, "MPTCP Socket: "
+                   "%s: mp_so 0x%llx invalid preconnect ", __func__,
+                   (u_int64_t)VM_KERNEL_ADDRPERM(mp_so)),
+                   MPTCP_SOCKET_DBG, MPTCP_LOGLVL_ERR);
+               return (EINVAL);
+       }
+       MPTS_LOCK(mpts);
+       mpts->mpts_flags &= ~MPTSF_TFO_REQD;
+       so = mpts->mpts_socket;
+       socket_lock(so, 0);
+       tp = intotcpcb(sotoinpcb(so));
+       tp->t_mpflags &= ~TMPF_TFO_REQUEST;
+       int error = tcp_output(sototcpcb(so));
+       socket_unlock(so, 0);
+       MPTS_UNLOCK(mpts);
+       mp_so->so_flags1 &= ~SOF1_PRECONNECT_DATA;
+       return (error);
+}