]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/mptcp_usrreq.c
xnu-7195.81.3.tar.gz
[apple/xnu.git] / bsd / netinet / mptcp_usrreq.c
index 0012e449760500252500d5aefc133147d738debd..ff16b486f8468a8e7f07a85194eddb7e2c9c83d4 100644 (file)
@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright (c) 2012-2017 Apple Inc. All rights reserved.
+ * Copyright (c) 2012-2020 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  *
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  *
@@ -126,16 +126,12 @@ mptcp_usr_attach(struct socket *mp_so, int proto, struct proc *p)
        VERIFY(mpsotomppcb(mp_so) == NULL);
 
        error = mptcp_attach(mp_so, p);
        VERIFY(mpsotomppcb(mp_so) == NULL);
 
        error = mptcp_attach(mp_so, p);
-       if (error != 0) {
+       if (error) {
                goto out;
        }
                goto out;
        }
-       /*
-        * XXX: adi@apple.com
-        *
-        * Might want to use a different SO_LINGER timeout than TCP's?
-        */
+
        if ((mp_so->so_options & SO_LINGER) && mp_so->so_linger == 0) {
        if ((mp_so->so_options & SO_LINGER) && mp_so->so_linger == 0) {
-               mp_so->so_linger = TCP_LINGERTIME * hz;
+               mp_so->so_linger = (short)(TCP_LINGERTIME * hz);
        }
 out:
        return error;
        }
 out:
        return error;
@@ -222,7 +218,7 @@ out:
 }
 
 static int
 }
 
 static int
-mptcp_entitlement_check(struct socket *mp_so)
+mptcp_entitlement_check(struct socket *mp_so, uint8_t svctype)
 {
        struct mptses *mpte = mpsotompte(mp_so);
 
 {
        struct mptses *mpte = mpsotompte(mp_so);
 
@@ -247,36 +243,17 @@ mptcp_entitlement_check(struct socket *mp_so)
                return 0;
        }
 
                return 0;
        }
 
-       /* Now, take a look at exceptions configured through sysctl */
-#if (DEVELOPMENT || DEBUG)
-       if (mptcp_disable_entitlements) {
-               return 0;
-       }
-#endif
-
-       if (mpte->mpte_svctype == MPTCP_SVCTYPE_AGGREGATE) {
+       if (svctype == MPTCP_SVCTYPE_AGGREGATE) {
                if (mptcp_developer_mode) {
                        return 0;
                }
 
                if (mptcp_developer_mode) {
                        return 0;
                }
 
-               goto deny;
-       }
-
-       /* Second, check for regular users that are within the data-limits */
-       if (soopt_cred_check(mp_so, PRIV_NET_PRIVILEGED_MULTIPATH, TRUE, FALSE) == 0) {
-               return 0;
-       }
-
-       if (mp_so->so_flags & SOF_DELEGATED &&
-           soopt_cred_check(mp_so, PRIV_NET_PRIVILEGED_MULTIPATH, TRUE, TRUE) == 0) {
-               return 0;
+               os_log_error(mptcp_log_handle, "%s - %lx: MPTCP prohibited on svc %u\n",
+                   __func__, (unsigned long)VM_KERNEL_ADDRPERM(mpte), mpte->mpte_svctype);
+               return -1;
        }
 
        }
 
-deny:
-       os_log_error(mptcp_log_handle, "%s - %lx: MPTCP prohibited on svc %u\n",
-           __func__, (unsigned long)VM_KERNEL_ADDRPERM(mpte), mpte->mpte_svctype);
-
-       return -1;
+       return 0;
 }
 
 /*
 }
 
 /*
@@ -354,7 +331,7 @@ mptcp_usr_connectx(struct socket *mp_so, struct sockaddr *src,
        }
 
        if (!(mpte->mpte_flags & MPTE_SVCTYPE_CHECKED)) {
        }
 
        if (!(mpte->mpte_flags & MPTE_SVCTYPE_CHECKED)) {
-               if (mptcp_entitlement_check(mp_so) < 0) {
+               if (mptcp_entitlement_check(mp_so, mpte->mpte_svctype) < 0) {
                        error = EPERM;
                        goto out;
                }
                        error = EPERM;
                        goto out;
                }
@@ -619,13 +596,18 @@ mptcp_getconninfo(struct mptses *mpte, sae_connid_t *cid, uint32_t *flags,
                return 0;
        } else {
                /* Per-interface stats */
                return 0;
        } else {
                /* Per-interface stats */
-               const struct mptsub *mpts, *orig_mpts;
+               const struct mptsub *mpts, *orig_mpts = NULL;
                struct conninfo_tcp tcp_ci;
                const struct inpcb *inp;
                struct socket *so;
                int error = 0;
                int index;
 
                struct conninfo_tcp tcp_ci;
                const struct inpcb *inp;
                struct socket *so;
                int error = 0;
                int index;
 
+               /* cid is thus an ifindex - range-check first! */
+               if (*cid > USHRT_MAX) {
+                       return EINVAL;
+               }
+
                bzero(&tcp_ci, sizeof(tcp_ci));
 
                /* First, get a subflow to fill in the "regular" info. */
                bzero(&tcp_ci, sizeof(tcp_ci));
 
                /* First, get a subflow to fill in the "regular" info. */
@@ -741,7 +723,7 @@ interface_info:
                         * nor anything in the stats, return EINVAL. Because the
                         * ifindex belongs to something that doesn't exist.
                         */
                         * nor anything in the stats, return EINVAL. Because the
                         * ifindex belongs to something that doesn't exist.
                         */
-                       index = mptcpstats_get_index_by_ifindex(mpte->mpte_itfstats, *cid, false);
+                       index = mptcpstats_get_index_by_ifindex(mpte->mpte_itfstats, (u_short)(*cid), false);
                        if (index == -1) {
                                os_log_error(mptcp_log_handle,
                                    "%s - %lx: Asking for too many ifindex: %u subcount %u, mpts? %s\n",
                        if (index == -1) {
                                os_log_error(mptcp_log_handle,
                                    "%s - %lx: Asking for too many ifindex: %u subcount %u, mpts? %s\n",
@@ -825,7 +807,7 @@ mptcp_usr_control(struct socket *mp_so, u_long cmd, caddr_t data,
                struct so_aidreq64 aidr;
                bcopy(data, &aidr, sizeof(aidr));
                error = mptcp_getassocids(mpte, &aidr.sar_cnt,
                struct so_aidreq64 aidr;
                bcopy(data, &aidr, sizeof(aidr));
                error = mptcp_getassocids(mpte, &aidr.sar_cnt,
-                   aidr.sar_aidp);
+                   (user_addr_t)aidr.sar_aidp);
                if (error == 0) {
                        bcopy(&aidr, data, sizeof(aidr));
                }
                if (error == 0) {
                        bcopy(&aidr, data, sizeof(aidr));
                }
@@ -847,7 +829,7 @@ mptcp_usr_control(struct socket *mp_so, u_long cmd, caddr_t data,
                struct so_cidreq64 cidr;
                bcopy(data, &cidr, sizeof(cidr));
                error = mptcp_getconnids(mpte, cidr.scr_aid, &cidr.scr_cnt,
                struct so_cidreq64 cidr;
                bcopy(data, &cidr, sizeof(cidr));
                error = mptcp_getconnids(mpte, cidr.scr_aid, &cidr.scr_cnt,
-                   cidr.scr_cidp);
+                   (user_addr_t)cidr.scr_cidp);
                if (error == 0) {
                        bcopy(&cidr, data, sizeof(cidr));
                }
                if (error == 0) {
                        bcopy(&cidr, data, sizeof(cidr));
                }
@@ -873,8 +855,9 @@ mptcp_usr_control(struct socket *mp_so, u_long cmd, caddr_t data,
                bcopy(data, &cifr, sizeof(cifr));
                error = mptcp_getconninfo(mpte, &cifr.scir_cid,
                    &cifr.scir_flags, &cifr.scir_ifindex, &cifr.scir_error,
                bcopy(data, &cifr, sizeof(cifr));
                error = mptcp_getconninfo(mpte, &cifr.scir_cid,
                    &cifr.scir_flags, &cifr.scir_ifindex, &cifr.scir_error,
-                   cifr.scir_src, &cifr.scir_src_len, cifr.scir_dst,
-                   &cifr.scir_dst_len, &cifr.scir_aux_type, cifr.scir_aux_data,
+                   (user_addr_t)cifr.scir_src, &cifr.scir_src_len,
+                   (user_addr_t)cifr.scir_dst, &cifr.scir_dst_len,
+                   &cifr.scir_aux_type, (user_addr_t)cifr.scir_aux_data,
                    &cifr.scir_aux_len);
                if (error == 0) {
                        bcopy(&cifr, data, sizeof(cifr));
                    &cifr.scir_aux_len);
                if (error == 0) {
                        bcopy(&cifr, data, sizeof(cifr));
@@ -1141,12 +1124,11 @@ out:
  * Copy the contents of uio into a properly sized mbuf chain.
  */
 static int
  * Copy the contents of uio into a properly sized mbuf chain.
  */
 static int
-mptcp_uiotombuf(struct uio *uio, int how, int space, uint32_t align,
-    struct mbuf **top)
+mptcp_uiotombuf(struct uio *uio, int how, user_ssize_t space, struct mbuf **top)
 {
        struct mbuf *m, *mb, *nm = NULL, *mtail = NULL;
 {
        struct mbuf *m, *mb, *nm = NULL, *mtail = NULL;
-       user_ssize_t resid, tot, len, progress; /* must be user_ssize_t */
-       int error;
+       int progress, len, error;
+       user_ssize_t resid, tot;
 
        VERIFY(top != NULL && *top == NULL);
 
 
        VERIFY(top != NULL && *top == NULL);
 
@@ -1156,24 +1138,17 @@ mptcp_uiotombuf(struct uio *uio, int how, int space, uint32_t align,
         */
        resid = uio_resid(uio);
        if (space > 0) {
         */
        resid = uio_resid(uio);
        if (space > 0) {
-               tot = imin(resid, space);
+               tot = MIN(resid, space);
        } else {
                tot = resid;
        }
 
        } else {
                tot = resid;
        }
 
-       /*
-        * The smallest unit is a single mbuf with pkthdr.
-        * We can't align past it.
-        */
-       if (align >= MHLEN) {
+       if (tot < 0 || tot > INT_MAX) {
                return EINVAL;
        }
 
                return EINVAL;
        }
 
-       /*
-        * Give us the full allocation or nothing.
-        * If space is zero return the smallest empty mbuf.
-        */
-       if ((len = tot + align) == 0) {
+       len = (int)tot;
+       if (len == 0) {
                len = 1;
        }
 
                len = 1;
        }
 
@@ -1214,12 +1189,12 @@ mptcp_uiotombuf(struct uio *uio, int how, int space, uint32_t align,
        }
 
        m = nm;
        }
 
        m = nm;
-       m->m_data += align;
 
        progress = 0;
        /* Fill all mbufs with uio data and update header information. */
        for (mb = m; mb != NULL; mb = mb->m_next) {
 
        progress = 0;
        /* Fill all mbufs with uio data and update header information. */
        for (mb = m; mb != NULL; mb = mb->m_next) {
-               len = imin(M_TRAILINGSPACE(mb), tot - progress);
+               /* tot >= 0 && tot <= INT_MAX (see above) */
+               len = MIN((int)M_TRAILINGSPACE(mb), (int)(tot - progress));
 
                error = uiomove(mtod(mb, char *), len, uio);
                if (error != 0) {
 
                error = uiomove(mtod(mb, char *), len, uio);
                if (error != 0) {
@@ -1246,8 +1221,7 @@ mptcp_usr_sosend(struct socket *mp_so, struct sockaddr *addr, struct uio *uio,
     struct mbuf *top, struct mbuf *control, int flags)
 {
 #pragma unused(addr)
     struct mbuf *top, struct mbuf *control, int flags)
 {
 #pragma unused(addr)
-       int32_t space;
-       user_ssize_t resid;
+       user_ssize_t resid, space;
        int error, sendflags;
        struct proc *p = current_proc();
        int sblocked = 0;
        int error, sendflags;
        struct proc *p = current_proc();
        int sblocked = 0;
@@ -1266,8 +1240,7 @@ mptcp_usr_sosend(struct socket *mp_so, struct sockaddr *addr, struct uio *uio,
        VERIFY(mp_so->so_type == SOCK_STREAM);
        VERIFY(!(mp_so->so_flags & SOF_MP_SUBFLOW));
 
        VERIFY(mp_so->so_type == SOCK_STREAM);
        VERIFY(!(mp_so->so_flags & SOF_MP_SUBFLOW));
 
-       if ((flags & (MSG_OOB | MSG_DONTROUTE)) ||
-           (mp_so->so_flags & SOF_ENABLE_MSGS)) {
+       if (flags & (MSG_OOB | MSG_DONTROUTE)) {
                error = EOPNOTSUPP;
                socket_unlock(mp_so, 1);
                goto out;
                error = EOPNOTSUPP;
                socket_unlock(mp_so, 1);
                goto out;
@@ -1280,7 +1253,8 @@ mptcp_usr_sosend(struct socket *mp_so, struct sockaddr *addr, struct uio *uio,
         * hand, a negative resid causes us to loop sending 0-length
         * segments to the protocol.
         */
         * hand, a negative resid causes us to loop sending 0-length
         * segments to the protocol.
         */
-       if (resid < 0 || (flags & MSG_EOR) || control != NULL) {
+       if (resid < 0 || resid > INT_MAX ||
+           (flags & MSG_EOR) || control != NULL) {
                error = EINVAL;
                socket_unlock(mp_so, 1);
                goto out;
                error = EINVAL;
                socket_unlock(mp_so, 1);
                goto out;
@@ -1290,7 +1264,7 @@ mptcp_usr_sosend(struct socket *mp_so, struct sockaddr *addr, struct uio *uio,
 
        do {
                error = sosendcheck(mp_so, NULL, resid, 0, 0, flags,
 
        do {
                error = sosendcheck(mp_so, NULL, resid, 0, 0, flags,
-                   &sblocked, NULL);
+                   &sblocked);
                if (error != 0) {
                        goto release;
                }
                if (error != 0) {
                        goto release;
                }
@@ -1301,7 +1275,7 @@ mptcp_usr_sosend(struct socket *mp_so, struct sockaddr *addr, struct uio *uio,
                        /*
                         * Copy the data from userland into an mbuf chain.
                         */
                        /*
                         * Copy the data from userland into an mbuf chain.
                         */
-                       error = mptcp_uiotombuf(uio, M_WAITOK, space, 0, &top);
+                       error = mptcp_uiotombuf(uio, M_WAITOK, space, &top);
                        if (error != 0) {
                                socket_lock(mp_so, 0);
                                goto release;
                        if (error != 0) {
                                socket_lock(mp_so, 0);
                                goto release;
@@ -1713,13 +1687,12 @@ mptcp_setopt(struct mptses *mpte, struct sockopt *sopt)
                                goto err_out;
                        }
 
                                goto err_out;
                        }
 
-                       mpte->mpte_svctype = optval;
-
-                       if (mptcp_entitlement_check(mp_so) < 0) {
+                       if (mptcp_entitlement_check(mp_so, (uint8_t)optval) < 0) {
                                error = EACCES;
                                goto err_out;
                        }
 
                                error = EACCES;
                                goto err_out;
                        }
 
+                       mpte->mpte_svctype = (uint8_t)optval;
                        mpte->mpte_flags |= MPTE_SVCTYPE_CHECKED;
 
                        goto out;
                        mpte->mpte_flags |= MPTE_SVCTYPE_CHECKED;
 
                        goto out;
@@ -1736,7 +1709,7 @@ mptcp_setopt(struct mptses *mpte, struct sockopt *sopt)
                                goto err_out;
                        }
 
                                goto err_out;
                        }
 
-                       mpte->mpte_alternate_port = optval;
+                       mpte->mpte_alternate_port = (uint16_t)optval;
 
                        goto out;
                case MPTCP_FORCE_ENABLE:
 
                        goto out;
                case MPTCP_FORCE_ENABLE:
@@ -1831,7 +1804,7 @@ mptcp_setopt(struct mptses *mpte, struct sockopt *sopt)
        if (rec) {
                /* search for an existing one; if not found, allocate */
                if ((mpo = mptcp_sopt_find(mpte, sopt)) == NULL) {
        if (rec) {
                /* search for an existing one; if not found, allocate */
                if ((mpo = mptcp_sopt_find(mpte, sopt)) == NULL) {
-                       mpo = mptcp_sopt_alloc(M_WAITOK);
+                       mpo = mptcp_sopt_alloc(Z_WAITOK);
                }
 
                if (mpo == NULL) {
                }
 
                if (mpo == NULL) {
@@ -1936,7 +1909,7 @@ mptcp_fill_info(struct mptses *mpte, struct tcp_info *ti)
 
        bzero(ti, sizeof(*ti));
 
 
        bzero(ti, sizeof(*ti));
 
-       ti->tcpi_state = mp_tp->mpt_state;
+       ti->tcpi_state = (uint8_t)mp_tp->mpt_state;
        /* tcpi_options */
        /* tcpi_snd_wscale */
        /* tcpi_rcv_wscale */
        /* tcpi_options */
        /* tcpi_snd_wscale */
        /* tcpi_rcv_wscale */
@@ -1957,8 +1930,8 @@ mptcp_fill_info(struct mptses *mpte, struct tcp_info *ti)
        /* tcpi_snd_cwnd */
        /* tcpi_rcv_space */
        ti->tcpi_snd_wnd = mp_tp->mpt_sndwnd;
        /* tcpi_snd_cwnd */
        /* tcpi_rcv_space */
        ti->tcpi_snd_wnd = mp_tp->mpt_sndwnd;
-       ti->tcpi_snd_nxt = mp_tp->mpt_sndnxt;
-       ti->tcpi_rcv_nxt = mp_tp->mpt_rcvnxt;
+       ti->tcpi_snd_nxt = (uint32_t)mp_tp->mpt_sndnxt;
+       ti->tcpi_rcv_nxt = (uint32_t)mp_tp->mpt_rcvnxt;
        if (acttp) {
                ti->tcpi_last_outif = (acttp->t_inpcb->inp_last_outifp == NULL) ? 0 :
                    acttp->t_inpcb->inp_last_outifp->if_index;
        if (acttp) {
                ti->tcpi_last_outif = (acttp->t_inpcb->inp_last_outifp == NULL) ? 0 :
                    acttp->t_inpcb->inp_last_outifp->if_index;
@@ -2028,6 +2001,7 @@ mptcp_getopt(struct mptses *mpte, struct sockopt *sopt)
        case PERSIST_TIMEOUT:
                /* Only case for which we have a non-zero default */
                optval = tcp_max_persist_timeout;
        case PERSIST_TIMEOUT:
                /* Only case for which we have a non-zero default */
                optval = tcp_max_persist_timeout;
+               OS_FALLTHROUGH;
        case TCP_NODELAY:
        case TCP_RXT_FINDROP:
        case TCP_KEEPALIVE:
        case TCP_NODELAY:
        case TCP_RXT_FINDROP:
        case TCP_KEEPALIVE: