]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/kern/kern_control.c
xnu-2422.1.72.tar.gz
[apple/xnu.git] / bsd / kern / kern_control.c
index a76088eb4889c49916e77abe78d1c7cfd1c86bcc..5a4cfd5f7880326e5b423172fcdf4bcf4a969c75 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1999-2008 Apple Computer, Inc. All rights reserved.
+ * Copyright (c) 1999-2012 Apple Inc. All rights reserved.
  *
  * @APPLE_OSREFERENCE_LICENSE_HEADER_START@
  * 
@@ -87,10 +87,12 @@ static int ctl_send(struct socket *, int, struct mbuf *,
             struct sockaddr *, struct mbuf *, struct proc *);
 static int ctl_ctloutput(struct socket *, struct sockopt *);
 static int ctl_peeraddr(struct socket *so, struct sockaddr **nam);
+static int ctl_usr_rcvd(struct socket *so, int flags);
 
 static struct kctl *ctl_find_by_name(const char *);
 static struct kctl *ctl_find_by_id_unit(u_int32_t id, u_int32_t unit);
 
+static struct socket *kcb_find_socket(struct kctl *, u_int32_t unit);
 static struct ctl_cb *kcb_find(struct kctl *, u_int32_t unit);
 static void ctl_post_msg(u_int32_t event_code, u_int32_t id);
 
@@ -98,103 +100,84 @@ static int ctl_lock(struct socket *, int, void *);
 static int ctl_unlock(struct socket *, int, void *);
 static lck_mtx_t * ctl_getlock(struct socket *, int);
 
-static struct pr_usrreqs ctl_usrreqs =
-{
-       pru_abort_notsupp, pru_accept_notsupp, ctl_attach, pru_bind_notsupp,
-       ctl_connect, pru_connect2_notsupp, ctl_ioctl, ctl_detach,
-       ctl_disconnect, pru_listen_notsupp, ctl_peeraddr,
-       pru_rcvd_notsupp, pru_rcvoob_notsupp, ctl_send,
-       pru_sense_null, pru_shutdown_notsupp, pru_sockaddr_notsupp,
-       sosend, soreceive, pru_sopoll_notsupp
+static struct pr_usrreqs ctl_usrreqs = {
+       .pru_attach =           ctl_attach,
+       .pru_connect =          ctl_connect,
+       .pru_control =          ctl_ioctl,
+       .pru_detach =           ctl_detach,
+       .pru_disconnect =       ctl_disconnect,
+       .pru_peeraddr =         ctl_peeraddr,
+       .pru_rcvd =             ctl_usr_rcvd,
+       .pru_send =             ctl_send,
+       .pru_sosend =           sosend,
+       .pru_soreceive =        soreceive,
 };
 
-static struct protosw kctlswk_dgram =
+static struct protosw kctlsw[] = {
 {
-       SOCK_DGRAM, &systemdomain, SYSPROTO_CONTROL, 
-       PR_ATOMIC|PR_CONNREQUIRED|PR_PCBLOCK,
-       NULL, NULL, NULL, ctl_ctloutput,
-       NULL, NULL,
-       NULL, NULL, NULL, NULL, &ctl_usrreqs,
-       ctl_lock, ctl_unlock, ctl_getlock, { 0, 0 } , 0, { 0 }
-};
-
-static struct protosw kctlswk_stream =
+       .pr_type =              SOCK_DGRAM,
+       .pr_protocol =          SYSPROTO_CONTROL,
+       .pr_flags =             PR_ATOMIC|PR_CONNREQUIRED|PR_PCBLOCK|PR_WANTRCVD,
+       .pr_ctloutput =         ctl_ctloutput,
+       .pr_usrreqs =           &ctl_usrreqs,
+       .pr_lock =              ctl_lock,
+       .pr_unlock =            ctl_unlock,
+       .pr_getlock =           ctl_getlock,
+},
 {
-       SOCK_STREAM, &systemdomain, SYSPROTO_CONTROL, 
-       PR_CONNREQUIRED|PR_PCBLOCK,
-       NULL, NULL, NULL, ctl_ctloutput,
-       NULL, NULL,
-       NULL, NULL, NULL, NULL, &ctl_usrreqs,
-       ctl_lock, ctl_unlock, ctl_getlock, { 0, 0 } , 0, { 0 }
+       .pr_type =              SOCK_STREAM,
+       .pr_protocol =          SYSPROTO_CONTROL,
+       .pr_flags =             PR_CONNREQUIRED|PR_PCBLOCK|PR_WANTRCVD,
+       .pr_ctloutput =         ctl_ctloutput,
+       .pr_usrreqs =           &ctl_usrreqs,
+       .pr_lock =              ctl_lock,
+       .pr_unlock =            ctl_unlock,
+       .pr_getlock =           ctl_getlock,
+}
 };
 
+static int kctl_proto_count = (sizeof (kctlsw) / sizeof (struct protosw));
 
 /*
  * Install the protosw's for the Kernel Control manager.
  */
-__private_extern__ int
-kern_control_init(void)
+__private_extern__ void
+kern_control_init(struct domain *dp)
 {
-       int error = 0;
-       
+       struct protosw *pr;
+       int i;
+
+       VERIFY(!(dp->dom_flags & DOM_INITIALIZED));
+       VERIFY(dp == systemdomain);
+
        ctl_lck_grp_attr = lck_grp_attr_alloc_init();
-       if (ctl_lck_grp_attr == 0) {
-                       printf(": lck_grp_attr_alloc_init failed\n");
-                       error = ENOMEM;
-                       goto done;
+       if (ctl_lck_grp_attr == NULL) {
+               panic("%s: lck_grp_attr_alloc_init failed\n", __func__);
+               /* NOTREACHED */
        }
-                       
-       ctl_lck_grp = lck_grp_alloc_init("Kernel Control Protocol", ctl_lck_grp_attr);
-       if (ctl_lck_grp == 0) {
-                       printf("kern_control_init: lck_grp_alloc_init failed\n");
-                       error = ENOMEM;
-                       goto done;
+
+       ctl_lck_grp = lck_grp_alloc_init("Kernel Control Protocol",
+           ctl_lck_grp_attr);
+       if (ctl_lck_grp == NULL) {
+               panic("%s: lck_grp_alloc_init failed\n", __func__);
+               /* NOTREACHED */
        }
-       
+
        ctl_lck_attr = lck_attr_alloc_init();
-       if (ctl_lck_attr == 0) {
-                       printf("kern_control_init: lck_attr_alloc_init failed\n");
-                       error = ENOMEM;
-                       goto done;
+       if (ctl_lck_attr == NULL) {
+               panic("%s: lck_attr_alloc_init failed\n", __func__);
+               /* NOTREACHED */
        }
-       
+
        ctl_mtx = lck_mtx_alloc_init(ctl_lck_grp, ctl_lck_attr);
-       if (ctl_mtx == 0) {
-                       printf("kern_control_init: lck_mtx_alloc_init failed\n");
-                       error = ENOMEM;
-                       goto done;
+       if (ctl_mtx == NULL) {
+               panic("%s: lck_mtx_alloc_init failed\n", __func__);
+               /* NOTREACHED */
        }
        TAILQ_INIT(&ctl_head);
-       
-       error = net_add_proto(&kctlswk_dgram, &systemdomain);
-       if (error) {
-               log(LOG_WARNING, "kern_control_init: net_add_proto dgram failed (%d)\n", error);
-       }
-       error = net_add_proto(&kctlswk_stream, &systemdomain);
-       if (error) {
-               log(LOG_WARNING, "kern_control_init: net_add_proto stream failed (%d)\n", error);
-       }
-       
-       done:
-       if (error != 0) {
-               if (ctl_mtx) {
-                               lck_mtx_free(ctl_mtx, ctl_lck_grp);
-                               ctl_mtx = 0;
-               }
-               if (ctl_lck_grp) {
-                               lck_grp_free(ctl_lck_grp);
-                               ctl_lck_grp = 0;
-               }
-               if (ctl_lck_grp_attr) {
-                               lck_grp_attr_free(ctl_lck_grp_attr);
-                               ctl_lck_grp_attr = 0;
-               }
-               if (ctl_lck_attr) {
-                               lck_attr_free(ctl_lck_attr);
-                               ctl_lck_attr = 0;
-               }
-       }
-       return error;
+
+       for (i = 0, pr = &kctlsw[0]; i < kctl_proto_count; i++, pr++)
+               net_add_proto(pr, dp, 1);
 }
 
 static void
@@ -255,7 +238,7 @@ ctl_sofreelastref(struct socket *so)
         if ((kctl = kcb->kctl) != 0) {
             lck_mtx_lock(ctl_mtx);
             TAILQ_REMOVE(&kctl->kcb_head, kcb, next);
-            lck_mtx_lock(ctl_mtx);
+            lck_mtx_unlock(ctl_mtx);
        }
        kcb_delete(kcb);
     }
@@ -364,10 +347,16 @@ ctl_connect(struct socket *so, struct sockaddr *nam, __unused struct proc *p)
     error = (*kctl->connect)(kctl, &sa, &kcb->userdata);
        socket_lock(so, 0);
     if (error)
-               goto done;
+               goto end;
     
     soisconnected(so);
 
+end:
+       if (error && kctl->disconnect) {
+               socket_unlock(so, 0);
+               (*kctl->disconnect)(kctl, kcb->unit, kcb->userdata);
+               socket_lock(so, 0);
+       }
 done:
     if (error) {
         soisdisconnected(so);
@@ -393,12 +382,19 @@ ctl_disconnect(struct socket *so)
             (*kctl->disconnect)(kctl, kcb->unit, kcb->userdata);
             socket_lock(so, 0);
         }
+        
+        soisdisconnected(so);
+        
+               socket_unlock(so, 0);
         lck_mtx_lock(ctl_mtx);
         kcb->kctl = 0;
        kcb->unit = 0;
+       while (kcb->usecount != 0) {
+               msleep(&kcb->usecount, ctl_mtx, 0, "kcb->usecount", 0);
+       }
         TAILQ_REMOVE(&kctl->kcb_head, kcb, next);
-        soisdisconnected(so);
         lck_mtx_unlock(ctl_mtx);
+               socket_lock(so, 0);
     }
     return 0;
 }
@@ -428,25 +424,50 @@ ctl_peeraddr(struct socket *so, struct sockaddr **nam)
        return 0;
 }
 
+static int
+ctl_usr_rcvd(struct socket *so, int flags)
+{
+       struct ctl_cb           *kcb = (struct ctl_cb *)so->so_pcb;
+       struct kctl                     *kctl;
+
+       if ((kctl = kcb->kctl) == NULL) {
+               return EINVAL;
+       }
+
+       if (kctl->rcvd) {
+               socket_unlock(so, 0);
+               (*kctl->rcvd)(kctl, kcb->unit, kcb->userdata, flags);
+               socket_lock(so, 0);
+       }
+
+       return 0;
+}
+
 static int
 ctl_send(struct socket *so, int flags, struct mbuf *m,
-            __unused struct sockaddr *addr, __unused struct mbuf *control,
+            __unused struct sockaddr *addr, struct mbuf *control,
             __unused struct proc *p)
 {
        int             error = 0;
        struct ctl_cb   *kcb = (struct ctl_cb *)so->so_pcb;
        struct kctl             *kctl;
        
+       if (control) m_freem(control);
+       
        if (kcb == NULL)        /* sanity check */
-               return(ENOTCONN);
+               error = ENOTCONN;
        
-       if ((kctl = kcb->kctl) == NULL)
-               return(EINVAL);
+       if (error == 0 && (kctl = kcb->kctl) == NULL)
+               error = EINVAL;
                
-       if (kctl->send) {
+       if (error == 0 && kctl->send) {
                socket_unlock(so, 0);
                error = (*kctl->send)(kctl, kcb->unit, kcb->userdata, m, flags);
                socket_lock(so, 0);
+       } else {
+               m_freem(m);
+               if (error == 0)
+                       error = ENOTSUP;
        }
        return error;
 }
@@ -454,23 +475,18 @@ ctl_send(struct socket *so, int flags, struct mbuf *m,
 errno_t
 ctl_enqueuembuf(void *kctlref, u_int32_t unit, struct mbuf *m, u_int32_t flags)
 {
-       struct ctl_cb   *kcb;
        struct socket   *so;
        errno_t                 error = 0;
        struct kctl             *kctl = (struct kctl *)kctlref;
        
        if (kctl == NULL)
                return EINVAL;
-               
-       kcb = kcb_find(kctl, unit);
-       if (kcb == NULL)
-               return EINVAL;
        
-       so = (struct socket *)kcb->so;
-       if (so == NULL) 
+       so = kcb_find_socket(kctl, unit);
+       
+       if (so == NULL)
                return EINVAL;
        
-       socket_lock(so, 1);
        if (sbspace(&so->so_rcv) < m->m_pkthdr.len) {
                error = ENOBUFS;
                goto bye;
@@ -487,7 +503,6 @@ bye:
 errno_t
 ctl_enqueuedata(void *kctlref, u_int32_t unit, void *data, size_t len, u_int32_t flags)
 {
-       struct ctl_cb   *kcb;
        struct socket   *so;
        struct mbuf     *m;
        errno_t                 error = 0;
@@ -499,15 +514,10 @@ ctl_enqueuedata(void *kctlref, u_int32_t unit, void *data, size_t len, u_int32_t
        if (kctlref == NULL)
                return EINVAL;
                
-       kcb = kcb_find(kctl, unit);
-       if (kcb == NULL)
-               return EINVAL;
-       
-       so = (struct socket *)kcb->so;
-       if (so == NULL) 
+       so = kcb_find_socket(kctl, unit);
+       if (so == NULL)
                return EINVAL;
        
-       socket_lock(so, 1);
        if (sbspace(&so->so_rcv) < (int)len) {
                error = ENOBUFS;
                goto bye;
@@ -545,27 +555,21 @@ bye:
 errno_t 
 ctl_getenqueuespace(kern_ctl_ref kctlref, u_int32_t unit, size_t *space)
 {
-       struct ctl_cb   *kcb;
        struct kctl             *kctl = (struct kctl *)kctlref;
        struct socket   *so;
        long avail;
        
        if (kctlref == NULL || space == NULL)
                return EINVAL;
-               
-       kcb = kcb_find(kctl, unit);
-       if (kcb == NULL)
-               return EINVAL;
        
-       so = (struct socket *)kcb->so;
-       if (so == NULL) 
+       so = kcb_find_socket(kctl, unit);
+       if (so == NULL)
                return EINVAL;
        
-       socket_lock(so, 1);
        avail = sbspace(&so->so_rcv);
        *space = (avail < 0) ? 0 : avail;
        socket_unlock(so, 1);
-
+       
        return 0;
 }
 
@@ -624,6 +628,9 @@ ctl_ctloutput(struct socket *so, struct sockopt *sopt)
                        socket_unlock(so, 0);
                        error = (*kctl->getopt)(kcb->kctl, kcb->unit, kcb->userdata, sopt->sopt_name, 
                                                data, &len);
+                       if (data != NULL && len > sopt->sopt_valsize)
+                               panic_plain("ctl_ctloutput: ctl %s returned len (%lu) > sopt_valsize (%lu)\n",
+                                       kcb->kctl->name, len, sopt->sopt_valsize);
                        socket_lock(so, 0);    
                        if (error == 0) {
                                if (data != NULL)
@@ -648,34 +655,38 @@ ctl_ioctl(__unused struct socket *so, u_long cmd, caddr_t data,
                /* get the number of controllers */
                case CTLIOCGCOUNT: {
                        struct kctl     *kctl;
-                       int n = 0;
+                       u_int32_t n = 0;
 
                        lck_mtx_lock(ctl_mtx);
                        TAILQ_FOREACH(kctl, &ctl_head, next)
                                n++;
                        lck_mtx_unlock(ctl_mtx);
-
-                       *(u_int32_t *)data = n;
+                       
+                       bcopy(&n, data, sizeof (n));
                        error = 0;
                        break;
                }
                case CTLIOCGINFO: {
-                       struct ctl_info *ctl_info = (struct ctl_info *)data;
+                       struct ctl_info ctl_info;
                        struct kctl     *kctl = 0;
-                       size_t name_len = strlen(ctl_info->ctl_name);
-                       
+                       size_t name_len;
+
+                       bcopy(data, &ctl_info, sizeof (ctl_info));
+                       name_len = strnlen(ctl_info.ctl_name, MAX_KCTL_NAME);
+
                        if (name_len == 0 || name_len + 1 > MAX_KCTL_NAME) {
                                error = EINVAL;
                                break;
                        }
                        lck_mtx_lock(ctl_mtx);
-                       kctl = ctl_find_by_name(ctl_info->ctl_name);
+                       kctl = ctl_find_by_name(ctl_info.ctl_name);
                        lck_mtx_unlock(ctl_mtx);
                        if (kctl == 0) {
                                error = ENOENT;
                                break;
                        }
-                       ctl_info->ctl_id = kctl->id;
+                       ctl_info.ctl_id = kctl->id;
+                       bcopy(&ctl_info, data, sizeof (ctl_info));
                        error = 0;
                        break;
                }
@@ -697,6 +708,7 @@ ctl_register(struct kern_ctl_reg *userkctl, kern_ctl_ref *kctlref)
        struct kctl     *kctl_next = NULL;
        u_int32_t               id = 1;
        size_t                  name_len;
+       int                             is_extended = 0;
        
        if (userkctl == NULL)   /* sanity check */
                return(EINVAL);
@@ -779,6 +791,9 @@ ctl_register(struct kern_ctl_reg *userkctl, kern_ctl_ref *kctlref)
                kctl->id = userkctl->ctl_id;
                kctl->reg_unit = userkctl->ctl_unit;
        }
+
+       is_extended = (userkctl->ctl_flags & CTL_FLAG_REG_EXTENDED);
+
        strlcpy(kctl->name, userkctl->ctl_name, MAX_KCTL_NAME);
        kctl->flags = userkctl->ctl_flags;
 
@@ -796,6 +811,9 @@ ctl_register(struct kern_ctl_reg *userkctl, kern_ctl_ref *kctlref)
        kctl->send = userkctl->ctl_send;
        kctl->setopt = userkctl->ctl_setopt;
        kctl->getopt = userkctl->ctl_getopt;
+       if (is_extended) {
+               kctl->rcvd = userkctl->ctl_rcvd;
+       }
        
        TAILQ_INIT(&kctl->kcb_head);
        
@@ -858,6 +876,46 @@ ctl_find_by_name(const char *name)
     return NULL;
 }
 
+u_int32_t
+ctl_id_by_name(const char *name)
+{
+       u_int32_t       ctl_id = 0;
+       
+       lck_mtx_lock(ctl_mtx);
+       struct kctl *kctl = ctl_find_by_name(name);
+       if (kctl) ctl_id = kctl->id;
+       lck_mtx_unlock(ctl_mtx);
+       
+       return ctl_id;
+}
+
+errno_t
+ctl_name_by_id(
+       u_int32_t id,
+       char    *out_name,
+       size_t  maxsize)
+{
+       int             found = 0;
+       
+       lck_mtx_lock(ctl_mtx);
+       struct kctl *kctl;
+    TAILQ_FOREACH(kctl, &ctl_head, next) {
+        if (kctl->id == id)
+            break;
+    }
+    
+    if (kctl && kctl->name)
+    {
+       if (maxsize > MAX_KCTL_NAME)
+               maxsize = MAX_KCTL_NAME;
+       strlcpy(out_name, kctl->name, maxsize);
+       found = 1;
+    }
+       lck_mtx_unlock(ctl_mtx);
+       
+       return found ? 0 : ENOENT;
+}
+
 /*
  * Must be called with global ctl_mtx lock taked
  *
@@ -885,21 +943,58 @@ kcb_find(struct kctl *kctl, u_int32_t unit)
     struct ctl_cb      *kcb;
 
     TAILQ_FOREACH(kcb, &kctl->kcb_head, next)
-        if ((kcb->unit == unit))
+        if (kcb->unit == unit)
             return kcb;
 
     return NULL;
 }
 
-/*
- * Must be called witout lock
- */
+static struct socket *
+kcb_find_socket(struct kctl *kctl, u_int32_t unit)
+{
+       struct socket *so = NULL;
+       
+       lck_mtx_lock(ctl_mtx);
+       struct ctl_cb   *kcb = kcb_find(kctl, unit);
+       if (kcb && kcb->kctl == kctl) {
+               so = kcb->so;
+               if (so) {
+                       kcb->usecount++;
+               }
+       }
+       lck_mtx_unlock(ctl_mtx);
+       
+       if (so == NULL) {
+               return NULL;
+       }
+       
+       socket_lock(so, 1);
+       
+       lck_mtx_lock(ctl_mtx);
+       if (kcb->kctl == NULL)
+       {
+               lck_mtx_unlock(ctl_mtx);
+               socket_unlock(so, 1);
+               so = NULL;
+               lck_mtx_lock(ctl_mtx);
+       }
+       kcb->usecount--;
+       if (kcb->usecount == 0)
+               wakeup((event_t)&kcb->usecount);
+       lck_mtx_unlock(ctl_mtx);
+       
+       return so;
+}
+
 static void 
 ctl_post_msg(u_int32_t event_code, u_int32_t id) 
 {
     struct ctl_event_data      ctl_ev_data;
     struct kev_msg             ev_msg;
     
+    lck_mtx_assert(ctl_mtx, LCK_MTX_ASSERT_NOTOWNED);
+   
+    bzero(&ev_msg, sizeof(struct kev_msg)); 
     ev_msg.vendor_code    = KEV_VENDOR_APPLE;
     
     ev_msg.kev_class      = KEV_SYSTEM_CLASS;