]> git.saurik.com Git - apple/libinfo.git/blobdiff - netinfo.subproj/ni_glue.c
Libinfo-89.tar.gz
[apple/libinfo.git] / netinfo.subproj / ni_glue.c
index 0a22e74740e32839eac98b91dcf06668f6b63c65..751cd9306a9d0cfb02a33bcd75d971e0bbeb24ba 100644 (file)
@@ -34,6 +34,9 @@
 #include <net/if.h>
 #include <ctype.h>
 #include "clib.h"
+#include "sys_interfaces.h"
+
+#define LOCAL_PORT 1033
 
 #define NI_TIMEOUT_SHORT 5     /* 5 second timeout for transactions */
 #define NI_TIMEOUT_LONG 60     /* 60 second timeout for writes */
@@ -44,7 +47,7 @@
 
 /* Hack for determining if an IP address is a broadcast address. -GRS */
 /* Note that addr is network byte order (big endian) - BKM */
-
 #define IS_BROADCASTADDR(addr) (((unsigned char *) &addr)[0] == 0xFF)
 
 #ifndef INADDR_LOOPBACK
@@ -117,46 +120,6 @@ getmyport(
 }
 
 
-/*
- * Is the NetInfo binder running?
- */
-static int
-nibind_up(
-       ni_private *ni
-       )
-{
-       int sock;
-       struct sockaddr_in sin;
-       int res;
-
-       sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
-       if (sock < 0) {
-               return (0);
-       }
-       sin.sin_family = AF_INET;
-       sin.sin_port = htons(PMAPPORT);
-       sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
-       bzero(sin.sin_zero, sizeof(sin.sin_zero));
-       res = connect(sock, (struct sockaddr *)&sin, sizeof(sin));
-       close(sock);
-       if (res != 0) {
-               return (0);
-       }
-       sin.sin_port = htons(pmap_getport(&sin, NIBIND_PROG, NIBIND_VERS, 
-                                         IPPROTO_TCP));
-       if (sin.sin_port == 0) {
-               return (0);
-       }
-       sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
-       if (sock < 0) {
-               return (0);
-       }
-       res = connect(sock, (struct sockaddr *)&sin, sizeof(sin));
-       close(sock);
-       return (res == 0);
-}
-       
-
 static void
 createauth(
           ni_private *ni
@@ -203,9 +166,7 @@ ni_settimeout(
  * Connect to a given address/tag
  */
 static int
-connectit(
-         ni_private *ni
-         )
+connectit(ni_private *ni)
 {
        struct sockaddr_in sin;
        int sock;
@@ -213,54 +174,82 @@ connectit(
        struct timeval tv;
        enum clnt_stat stat;
        nibind_getregister_res res;
-       
+
+       sock = -1;
        bzero(&sin, sizeof(sin));
        sin.sin_port = 0;
        sin.sin_family = AF_INET;
-       sin.sin_addr = ni->addrs[0];
-       ni_settimeout(ni, ni->rtv_sec == 0 ? NI_TIMEOUT_SHORT : ni->rtv_sec);
-       fixtimeout(&tv, ni->tv_sec, NI_TRIES);
-       sock = socket_open(&sin, NIBIND_PROG, NIBIND_VERS, ni->tv_sec, 
-                          NI_TRIES, IPPROTO_UDP);
-       if (sock < 0) {
-               return (0);
-       }
-       cl = clntudp_create(&sin, NIBIND_PROG, NIBIND_VERS, tv,
-                           &sock);
-       if (cl == NULL) {
-               close(sock);
-               return (0);
-       }
+       
        tv.tv_sec = ni->rtv_sec == 0 ? NI_TIMEOUT_SHORT : ni->rtv_sec;
        tv.tv_usec = 0;
-       stat = clnt_call(cl, NIBIND_GETREGISTER, xdr_ni_name, &ni->tags[0],
-                        xdr_nibind_getregister_res, &res, tv);
-       clnt_destroy(cl);
-       close(sock);
-       if (stat != RPC_SUCCESS || res.status != NI_OK) {
-               return (0);
-       }
        
+       ni_settimeout(ni, tv.tv_sec);
+       fixtimeout(&tv, ni->tv_sec, NI_TRIES);
+
        /*
-        * Found the address, now connect to it. 
+        * If connecting to local domain, try using the "well-known" port first.
         */
-       sin.sin_port = htons(res.nibind_getregister_res_u.addrs.tcp_port);
-       sock = socket_open(&sin, NI_PROG, NI_VERS, ni->tv_sec, NI_TRIES,
-                          IPPROTO_TCP);
-       if (sock < 0) {
-               return (0);
+       if (!strcmp(ni->tags[0], "local"))
+       {
+               interface_list_t *ilist;
+
+               ilist = sys_interfaces();
+               if (sys_is_my_address(ilist, &ni->addrs[0]))
+               {
+                       sin.sin_port = htons(LOCAL_PORT);
+                       sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+                       sock = socket_open(&sin, NI_PROG, NI_VERS, ni->tv_sec, NI_TRIES, IPPROTO_TCP);
+               }
+               sys_interfaces_release(ilist);
+       }
+
+       /*
+        * If connecting to a domain other than the local domain,
+        * or if connection to local didn't work with local's well-known port,
+        * then go through portmap & nibindd to find the port and connect.
+        */
+       if (sock < 0)
+       {
+               sin.sin_port = 0;
+               sin.sin_addr = ni->addrs[0];
+
+               sock = socket_open(&sin, NIBIND_PROG, NIBIND_VERS, ni->tv_sec, NI_TRIES, IPPROTO_UDP);
+               if (sock < 0) return (0);
+
+               cl = clntudp_create(&sin, NIBIND_PROG, NIBIND_VERS, tv, &sock);
+               if (cl == NULL)
+               {
+                       close(sock);
+                       return (0);
+               }
+
+               tv.tv_sec = ni->rtv_sec == 0 ? NI_TIMEOUT_SHORT : ni->rtv_sec;
+               tv.tv_usec = 0;
+
+               stat = clnt_call(cl, NIBIND_GETREGISTER, xdr_ni_name, &ni->tags[0], xdr_nibind_getregister_res, &res, tv);
+               clnt_destroy(cl);
+               close(sock);
+               if (stat != RPC_SUCCESS || res.status != NI_OK) return (0);
+       
+               sin.sin_port = htons(res.nibind_getregister_res_u.addrs.tcp_port);
+               sock = socket_open(&sin, NI_PROG, NI_VERS, ni->tv_sec, NI_TRIES, IPPROTO_TCP);
        }
+
+       if (sock < 0) return (0);
+
        cl = clnttcp_create(&sin, NI_PROG, NI_VERS, &sock, 0, 0);
-       if (cl == NULL) {
+       if (cl == NULL)
+       {
                close(sock);
                return (0);
        }
+
        clnt_control(cl, CLSET_TIMEOUT, &tv);
        ni->tc = cl;
        ni->tsock = sock;
        ni->tport = getmyport(sock);
        createauth(ni);
-       (void) fcntl(ni->tsock, F_SETFD, 1);
+       fcntl(ni->tsock, F_SETFD, 1);
        return (1);
 }
 
@@ -305,118 +294,6 @@ ni_needwrite(
 }
 
 
-static unsigned long
-sys_netmask(void)
-{
-       struct ifconf ifc;
-       struct ifreq *ifr;
-       char buf[1024]; /* XXX */
-       int i, len, ifreq_size, offset, sockaddr_size, size;
-       int sock;
-       struct sockaddr_in *sin;
-       unsigned long n_addr;
-
-       sock = socket(AF_INET, SOCK_DGRAM, 0);
-
-       if (sock < 0) return (htonl(IN_CLASSA_NET));
-
-       ifc.ifc_len = sizeof(buf);
-       ifc.ifc_buf = buf;
-
-       if (ioctl(sock, SIOCGIFCONF, (char *)&ifc) < 0)
-       {
-               close(sock);
-               return (htonl(IN_CLASSA_NET));
-       }
-
-       ifreq_size = sizeof(struct ifreq);
-       sockaddr_size = sizeof(struct sockaddr);
-
-       offset = 0;
-       len = ifc.ifc_len / ifreq_size;
-       for (i = 0; i < len; i++)
-       {
-               ifr = (struct ifreq *)(ifc.ifc_buf + offset);
-               offset += IFNAMSIZ;
-               offset += sockaddr_size;
-
-               size = ifr->ifr_addr.sa_len;
-               if (size > sockaddr_size) offset += (size - sockaddr_size);
-
-               if (ifr->ifr_addr.sa_family != AF_INET) continue;
-               if (ioctl(sock, SIOCGIFFLAGS, (char *)ifr) < 0) continue;
-
-               sin = (struct sockaddr_in *)&ifr->ifr_addr;
-               if ((ifr->ifr_flags & IFF_UP) &&
-                       !(ifr->ifr_flags & IFF_LOOPBACK) &&
-                       (sin->sin_addr.s_addr != 0))
-               {
-                       ioctl(sock, SIOCGIFNETMASK, (char *)ifr);
-                       n_addr = ((struct sockaddr_in *)&(ifr->ifr_addr))->sin_addr.s_addr;
-                       close(sock);
-                       return (n_addr);
-               }
-       }
-
-       close(sock);
-       return (htonl(IN_CLASSA_NET));
-}
-
-
-static unsigned long
-sys_address(void)
-{
-       struct ifconf ifc;
-       struct ifreq *ifr;
-       char buf[1024]; /* XXX */
-       int i, len, ifreq_size, offset, sockaddr_size, size;
-       int sock;
-
-       sock = socket(AF_INET, SOCK_DGRAM, 0);
-
-       if (sock < 0) 
-       {
-               return (htonl(INADDR_LOOPBACK));
-       }
-
-       ifc.ifc_len = sizeof(buf);
-       ifc.ifc_buf = buf;
-
-       if (ioctl(sock, SIOCGIFCONF, (char *)&ifc) < 0)
-       {
-               close(sock);
-               return (htonl(INADDR_LOOPBACK));
-       }
-
-       ifreq_size = sizeof(struct ifreq);
-       sockaddr_size = sizeof(struct sockaddr);
-
-       offset = 0;
-       len = ifc.ifc_len / ifreq_size;
-       for (i = 0; i < len; i++)
-       {
-               ifr = (struct ifreq *)(ifc.ifc_buf + offset);
-               offset += IFNAMSIZ;
-               offset += sockaddr_size;
-
-               size = ifr->ifr_addr.sa_len;
-               if (size > sockaddr_size) offset += (size - sockaddr_size);
-
-               if (ifr->ifr_addr.sa_family != AF_INET) continue;
-               if (ioctl(sock, SIOCGIFFLAGS, ifr) < 0) continue;
-
-               if ((ifr->ifr_flags & IFF_UP) && (!(ifr->ifr_flags & IFF_LOOPBACK)))
-               {
-                       close(sock);
-                       return (((struct sockaddr_in *)&(ifr->ifr_addr))->sin_addr.s_addr);
-               }
-       }
-
-       close(sock);
-       return (htonl(INADDR_LOOPBACK));
-}
-
-
 /*
  * Returns a client handle to the NetInfo server, if it's running
  */
@@ -425,26 +302,30 @@ connectlocal(ni_private *ni)
 {
        int printed = 0;
 
-       if (!nibind_up(ni)) {
-               return (0);
-       }
        ni->naddrs = 1;
        ni->addrs = (struct in_addr *)malloc(sizeof(struct in_addr));
        ni->addrs[0].s_addr = htonl(INADDR_LOOPBACK);
        ni->tags = (ni_name *)malloc(sizeof(ni_name));
        ni->tags[0] = ni_name_dup("local");
        ni->whichwrite = 0;
-       while (!connectit(ni)) {
-               if (!printed) {
+
+       while (!connectit(ni))
+       {
+               if (!printed)
+               {
                        syslog(LOG_ERR, "NetInfo timeout connecting to local domain, sleeping");
                        printed++;
                }
+
                sleep(NI_SLEEPTIME);
                /* wait forever */
        }
-       if (printed) {
+
+       if (printed)
+       {
                syslog(LOG_ERR, "NetInfo connection to local domain waking");
        }
+
        return (1);
 }
 
@@ -546,6 +427,8 @@ ni_swap(
        struct in_addr tmp_addr;
        ni_name tmp_tag;
 
+       if (a == b) return;
+
        tmp_addr = ni->addrs[a];
        tmp_tag = ni->tags[a];
 
@@ -585,6 +468,51 @@ eachresult(
 }
 
 
+/*
+ * shuffle addresses
+ */
+static void
+shuffle(ni_private *ni)
+{
+       int *shuffle;
+       int i, j;
+       int rfd;
+
+       if (ni->naddrs <= 1) return;
+
+       rfd = open("/dev/random", O_RDONLY, 0);
+       shuffle = (int *)malloc(ni->naddrs * sizeof(int));
+       for (i = 0; i < ni->naddrs; i++) shuffle[i] = i;
+       for (i = 0, j = ni->naddrs; j > 0; i++, j--) {
+               unsigned int rEnt;
+               long rVal;
+               int tEnt;
+
+               /* get a random number */
+               if ((rfd < 0) ||
+                   (read(rfd, &rVal, sizeof(rVal)) != sizeof(rVal))) {
+                       /* if we could not read from /dev/random */
+                       static int initialized = 0;
+                       if (!initialized)
+                       {
+                               srandom(gethostid() ^ time(NULL));
+                               initialized++;
+                       }
+                       rVal = random();
+               }
+
+               rEnt = (unsigned int)rVal % j;  /* pick one of the remaining entries */
+               tEnt = shuffle[rEnt];           /* grab the random entry */
+               shuffle[rEnt] = shuffle[j-1];   /* the last entry moves to the random slot */ 
+               shuffle[j-1]  = tEnt;           /* the last slot gets the random entry */
+               ni_swap(ni, rEnt, j-1);         /* and swap the actual NI addresses */
+       }
+       free(shuffle);
+       if (rfd > 0) (void)close(rfd);
+       return;
+}
+
+
 static int
 rebind(
        ni_private *ni
@@ -596,9 +524,7 @@ rebind(
        int printed = 0;
        int nlocal;
        int nnetwork;
-       unsigned long myaddr;
-       unsigned long mynetmask;
-       unsigned long mynetwork;
+       interface_list_t *ilist;
        int i;
 
        if (ni->naddrs == 1) {
@@ -614,17 +540,19 @@ rebind(
         * all other servers are next
         */
 
-       myaddr = sys_address();
-       mynetmask = sys_netmask();
-       mynetwork = myaddr & mynetmask;
+       ilist = sys_interfaces();
+
+       /*
+        * shuffle addresses
+        */
+       shuffle(ni);
 
        /*
         * move local servers to the head of the list
         */
        nlocal = 0;
        for (i = nlocal; i < ni->naddrs; i++) {
-               if ((ni->addrs[i].s_addr == myaddr) ||
-                       (ni->addrs[i].s_addr == htonl(INADDR_LOOPBACK)))
+               if (sys_is_my_address(ilist, &ni->addrs[i]))
                {
                        ni_swap(ni, nlocal, i);
                        nlocal++;
@@ -636,7 +564,7 @@ rebind(
         */
        nnetwork = nlocal;
        for (i = nnetwork; i < ni->naddrs; i++) {
-               if (((ni->addrs[i].s_addr & mynetmask) == mynetwork) ||
+               if (sys_is_my_network(ilist, &ni->addrs[i]) ||
                        IS_BROADCASTADDR(ni->addrs[i].s_addr))
                {
                        ni_swap(ni, nnetwork, i);
@@ -644,6 +572,8 @@ rebind(
                }
        }
 
+       sys_interfaces_release(ilist);
+
        stuff.ni = ni;
        for (;;) {
                /*
@@ -2220,7 +2150,8 @@ socket_open(
            )
 {
        int sock;
-       
+       int reuse = 1;
+
        /*
         * If no port number given ask the pmap for one
         */
@@ -2240,6 +2171,7 @@ socket_open(
                return (-1);
        }
        (void)bindresvport(sock, (struct sockaddr_in *)0);
+       setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &reuse, sizeof(int));
        if (proto == IPPROTO_TCP) {
                if (connect(sock, (struct sockaddr *)raddr,
                            sizeof(*raddr)) < 0) {