]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/kern/kern_control.c
xnu-6153.41.3.tar.gz
[apple/xnu.git] / bsd / kern / kern_control.c
index dede2e6e85d521e62f27f945fc06286bb2f615c0..e41d1f103d8528c6e6fe6a870a0bb5bde632ca56 100644 (file)
@@ -93,6 +93,7 @@ struct ctl_cb {
        void                    *userdata;
        struct sockaddr_ctl     sac;
        u_int32_t               usecount;
+       u_int32_t               kcb_usecount;
 };
 
 #ifndef ROUNDUP64
@@ -351,6 +352,27 @@ ctl_sofreelastref(struct socket *so)
        return 0;
 }
 
+/*
+ * Use this function to serialize calls into the kctl subsystem
+ */
+static void
+ctl_kcb_increment_use_count(struct ctl_cb *kcb, lck_mtx_t *mutex_held)
+{
+       LCK_MTX_ASSERT(mutex_held, LCK_MTX_ASSERT_OWNED);
+       while (kcb->kcb_usecount > 0) {
+               msleep(&kcb->kcb_usecount, mutex_held, PSOCK | PCATCH, "kcb_usecount", NULL);
+       }
+       kcb->kcb_usecount++;
+}
+
+static void
+clt_kcb_decrement_use_count(struct ctl_cb *kcb)
+{
+       assert(kcb->kcb_usecount != 0);
+       kcb->kcb_usecount--;
+       wakeup_one((caddr_t)&kcb->kcb_usecount);
+}
+
 static int
 ctl_detach(struct socket *so)
 {
@@ -360,6 +382,9 @@ ctl_detach(struct socket *so)
                return 0;
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        if (kcb->kctl != NULL && kcb->kctl->bind != NULL &&
            kcb->userdata != NULL && !(so->so_state & SS_ISCONNECTED)) {
                // The unit was bound, but not connected
@@ -374,6 +399,7 @@ ctl_detach(struct socket *so)
 
        soisdisconnected(so);
        so->so_flags |= SOF_PCBCLEARING;
+       clt_kcb_decrement_use_count(kcb);
        return 0;
 }
 
@@ -522,9 +548,12 @@ ctl_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
                panic("ctl_bind so_pcb null\n");
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        error = ctl_setup_kctl(so, nam, p);
        if (error) {
-               return error;
+               goto out;
        }
 
        if (kcb->kctl == NULL) {
@@ -532,13 +561,16 @@ ctl_bind(struct socket *so, struct sockaddr *nam, struct proc *p)
        }
 
        if (kcb->kctl->bind == NULL) {
-               return EINVAL;
+               error = EINVAL;
+               goto out;
        }
 
        socket_unlock(so, 0);
        error = (*kcb->kctl->bind)(kcb->kctl->kctlref, &kcb->sac, &kcb->userdata);
        socket_lock(so, 0);
 
+out:
+       clt_kcb_decrement_use_count(kcb);
        return error;
 }
 
@@ -552,9 +584,12 @@ ctl_connect(struct socket *so, struct sockaddr *nam, struct proc *p)
                panic("ctl_connect so_pcb null\n");
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        error = ctl_setup_kctl(so, nam, p);
        if (error) {
-               return error;
+               goto out;
        }
 
        if (kcb->kctl == NULL) {
@@ -596,6 +631,8 @@ end:
                kctlstat.kcs_conn_fail++;
                lck_mtx_unlock(ctl_mtx);
        }
+out:
+       clt_kcb_decrement_use_count(kcb);
        return error;
 }
 
@@ -605,6 +642,8 @@ ctl_disconnect(struct socket *so)
        struct ctl_cb   *kcb = (struct ctl_cb *)so->so_pcb;
 
        if ((kcb = (struct ctl_cb *)so->so_pcb)) {
+               lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+               ctl_kcb_increment_use_count(kcb, mtx_held);
                struct kctl             *kctl = kcb->kctl;
 
                if (kctl && kctl->disconnect) {
@@ -628,6 +667,7 @@ ctl_disconnect(struct socket *so)
                kctlstat.kcs_gencnt++;
                lck_mtx_unlock(ctl_mtx);
                socket_lock(so, 0);
+               clt_kcb_decrement_use_count(kcb);
        }
        return 0;
 }
@@ -694,11 +734,20 @@ ctl_sbrcv_trim(struct socket *so)
 static int
 ctl_usr_rcvd(struct socket *so, int flags)
 {
+       int                     error = 0;
        struct ctl_cb           *kcb = (struct ctl_cb *)so->so_pcb;
        struct kctl                     *kctl;
 
+       if (kcb == NULL) {
+               return ENOTCONN;
+       }
+
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        if ((kctl = kcb->kctl) == NULL) {
-               return EINVAL;
+               error = EINVAL;
+               goto out;
        }
 
        if (kctl->rcvd) {
@@ -709,7 +758,9 @@ ctl_usr_rcvd(struct socket *so, int flags)
 
        ctl_sbrcv_trim(so);
 
-       return 0;
+out:
+       clt_kcb_decrement_use_count(kcb);
+       return error;
 }
 
 static int
@@ -730,6 +781,9 @@ ctl_send(struct socket *so, int flags, struct mbuf *m,
                error = ENOTCONN;
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        if (error == 0 && (kctl = kcb->kctl) == NULL) {
                error = EINVAL;
        }
@@ -749,6 +803,8 @@ ctl_send(struct socket *so, int flags, struct mbuf *m,
        if (error != 0) {
                OSIncrementAtomic64((SInt64 *)&kctlstat.kcs_send_fail);
        }
+       clt_kcb_decrement_use_count(kcb);
+
        return error;
 }
 
@@ -769,6 +825,9 @@ ctl_send_list(struct socket *so, int flags, struct mbuf *m,
                error = ENOTCONN;
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        if (error == 0 && (kctl = kcb->kctl) == NULL) {
                error = EINVAL;
        }
@@ -808,6 +867,8 @@ ctl_send_list(struct socket *so, int flags, struct mbuf *m,
        if (error != 0) {
                OSIncrementAtomic64((SInt64 *)&kctlstat.kcs_send_list_fail);
        }
+       clt_kcb_decrement_use_count(kcb);
+
        return error;
 }
 
@@ -1234,16 +1295,21 @@ ctl_ctloutput(struct socket *so, struct sockopt *sopt)
                return EINVAL;
        }
 
+       lck_mtx_t *mtx_held = socket_getlock(so, PR_F_WILLUNLOCK);
+       ctl_kcb_increment_use_count(kcb, mtx_held);
+
        switch (sopt->sopt_dir) {
        case SOPT_SET:
                if (kctl->setopt == NULL) {
-                       return ENOTSUP;
+                       error = ENOTSUP;
+                       goto out;
                }
                if (sopt->sopt_valsize != 0) {
                        MALLOC(data, void *, sopt->sopt_valsize, M_TEMP,
                            M_WAITOK | M_ZERO);
                        if (data == NULL) {
-                               return ENOMEM;
+                               error = ENOMEM;
+                               goto out;
                        }
                        error = sooptcopyin(sopt, data,
                            sopt->sopt_valsize, sopt->sopt_valsize);
@@ -1263,14 +1329,16 @@ ctl_ctloutput(struct socket *so, struct sockopt *sopt)
 
        case SOPT_GET:
                if (kctl->getopt == NULL) {
-                       return ENOTSUP;
+                       error = ENOTSUP;
+                       goto out;
                }
 
                if (sopt->sopt_valsize && sopt->sopt_val) {
                        MALLOC(data, void *, sopt->sopt_valsize, M_TEMP,
                            M_WAITOK | M_ZERO);
                        if (data == NULL) {
-                               return ENOMEM;
+                               error = ENOMEM;
+                               goto out;
                        }
                        /*
                         * 4108337 - copy user data in case the
@@ -1306,6 +1374,9 @@ ctl_ctloutput(struct socket *so, struct sockopt *sopt)
                }
                break;
        }
+
+out:
+       clt_kcb_decrement_use_count(kcb);
        return error;
 }