]> git.saurik.com Git - apt.git/blobdiff - apt-pkg/contrib/srvrec.cc
reset HOME, USER(NAME), TMPDIR & SHELL in DropPrivileges
[apt.git] / apt-pkg / contrib / srvrec.cc
index 70247c2aec9e286355dff23526d19928e5a53248..327e59937eda49b4bfdaa66efae58f4ead07efbd 100644 (file)
 #include <netinet/in.h>
 #include <arpa/nameser.h>
 #include <resolv.h>
 #include <netinet/in.h>
 #include <arpa/nameser.h>
 #include <resolv.h>
+#include <time.h>
 
 #include <algorithm>
 
 #include <algorithm>
+#include <tuple>
 
 
-#include <apt-pkg/strutl.h>
+#include <apt-pkg/configuration.h>
 #include <apt-pkg/error.h>
 #include <apt-pkg/error.h>
+#include <apt-pkg/strutl.h>
+
+
 #include "srvrec.h"
 
 #include "srvrec.h"
 
+
+bool SrvRec::operator==(SrvRec const &other) const  // memberwise equality of two SRV records
+{
+   return (std::tie(target, priority, weight, port) ==  // only these four fields are compared
+           std::tie(other.target, other.priority, other.weight, other.port));  // any other SrvRec members are ignored
+}
+
 bool GetSrvRecords(std::string host, int port, std::vector<SrvRec> &Result)
 {
    std::string target;
 bool GetSrvRecords(std::string host, int port, std::vector<SrvRec> &Result)
 {
    std::string target;
-   struct servent *s_ent = getservbyport(htons(port), "tcp");
-   if (s_ent == NULL)
+   int res;
+   struct servent s_ent_buf;
+   struct servent *s_ent = nullptr;
+   std::vector<char> buf(1024);
+
+   res = getservbyport_r(htons(port), "tcp", &s_ent_buf, buf.data(), buf.size(), &s_ent);
+   if (res != 0 || s_ent == nullptr)
       return false;
 
    strprintf(target, "_%s._tcp.%s", s_ent->s_name, host.c_str());
       return false;
 
    strprintf(target, "_%s._tcp.%s", s_ent->s_name, host.c_str());
@@ -41,6 +58,8 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
       return _error->Errno("res_init", "Failed to init resolver");
 
    answer_len = res_query(name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
       return _error->Errno("res_init", "Failed to init resolver");
 
    answer_len = res_query(name.c_str(), C_IN, T_SRV, answer, sizeof(answer));
+   if (answer_len == -1)
+      return false;
    if (answer_len < (int)sizeof(HEADER))
       return _error->Warning("Not enough data from res_query (%i)", answer_len);
 
    if (answer_len < (int)sizeof(HEADER))
       return _error->Warning("Not enough data from res_query (%i)", answer_len);
 
@@ -61,10 +80,9 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
    unsigned char *pt = answer+sizeof(HEADER)+compressed_name_len+QFIXEDSZ;
    while ((int)Result.size() < answer_count && pt < answer+answer_len)
    {
    unsigned char *pt = answer+sizeof(HEADER)+compressed_name_len+QFIXEDSZ;
    while ((int)Result.size() < answer_count && pt < answer+answer_len)
    {
-      SrvRec rec;
       u_int16_t type, klass, priority, weight, port, dlen;
       char buf[MAXDNAME];
       u_int16_t type, klass, priority, weight, port, dlen;
       char buf[MAXDNAME];
-      
+
       compressed_name_len = dn_skipname(pt, answer+answer_len);
       if (compressed_name_len < 0)
          return _error->Warning("dn_skipname failed (2): %i",
       compressed_name_len = dn_skipname(pt, answer+answer_len);
       if (compressed_name_len < 0)
          return _error->Warning("dn_skipname failed (2): %i",
@@ -74,15 +92,15 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
          return _error->Warning("packet too short");
 
       // extract the data out of the result buffer
          return _error->Warning("packet too short");
 
       // extract the data out of the result buffer
-      #define extract_u16(target, p) target = *p++ << 8; target |= *p++; 
+      #define extract_u16(target, p) target = *p++ << 8; target |= *p++;
 
       extract_u16(type, pt);
       if(type != T_SRV)
 
       extract_u16(type, pt);
       if(type != T_SRV)
-         return _error->Warning("Unexpected type excepted %x != %x", 
+         return _error->Warning("Unexpected type excepted %x != %x",
                                 T_SRV, type);
       extract_u16(klass, pt);
       if(klass != C_IN)
                                 T_SRV, type);
       extract_u16(klass, pt);
       if(klass != C_IN)
-         return _error->Warning("Unexpected class excepted %x != %x", 
+         return _error->Warning("Unexpected class excepted %x != %x",
                                 C_IN, klass);
       pt += 4;  // ttl
       extract_u16(dlen, pt);
                                 C_IN, klass);
       pt += 4;  // ttl
       extract_u16(dlen, pt);
@@ -98,11 +116,7 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
       pt += compressed_name_len;
 
       // add it to our class
       pt += compressed_name_len;
 
       // add it to our class
-      rec.priority = priority;
-      rec.weight = weight;
-      rec.port = port;
-      rec.target = buf;
-      Result.push_back(rec);
+      Result.emplace_back(buf, priority, weight, port);
    }
 
    // implement load balancing as specified in RFC-2782
    }
 
    // implement load balancing as specified in RFC-2782
@@ -110,6 +124,29 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
    // sort them by priority
    std::stable_sort(Result.begin(), Result.end());
 
    // sort them by priority
    std::stable_sort(Result.begin(), Result.end());
 
+   for(std::vector<SrvRec>::iterator I = Result.begin();
+      I != Result.end(); ++I)
+   {
+      if (_config->FindB("Debug::Acquire::SrvRecs", false) == true)
+      {
+         std::cerr << "SrvRecs: got " << I->target
+                   << " prio: " << I->priority
+                   << " weight: " << I->weight
+                   << std::endl;
+      }
+   }
+
+   return true;
+}
+
+SrvRec PopFromSrvRecs(std::vector<SrvRec> &Recs)
+{
+   // FIXME: instead of the simplistic shuffle below use the algorithm
+   //        described in rfc2782 (with weights)
+   //        and figure out how the weights need to be adjusted if
+   //        a host refuses connections
+
+#if 0  // all code below is only needed for the weight adjusted selection 
    // assign random number ranges
    int prev_weight = 0;
    int prev_priority = 0;
    // assign random number ranges
    int prev_weight = 0;
    int prev_priority = 0;
@@ -122,6 +159,12 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
       I->random_number_range_end = prev_weight + I->weight;
       prev_weight = I->random_number_range_end;
       prev_priority = I->priority;
       I->random_number_range_end = prev_weight + I->weight;
       prev_weight = I->random_number_range_end;
       prev_priority = I->priority;
+
+      if (_config->FindB("Debug::Acquire::SrvRecs", false) == true)
+         std::cerr << "SrvRecs: got " << I->target
+                   << " prio: " << I->priority
+                   << " weight: " << I->weight
+                   << std::endl;
    }
 
    // go over the code in reverse order and note the max random range
    }
 
    // go over the code in reverse order and note the max random range
@@ -134,8 +177,20 @@ bool GetSrvRecords(std::string name, std::vector<SrvRec> &Result)
          max = I->random_number_range_end;
       I->random_number_range_max = max;
    }
          max = I->random_number_range_end;
       I->random_number_range_max = max;
    }
+#endif
 
 
-   // now shuffle 
+   // shuffle in a very simplistic way for now (equal weights)
+   std::vector<SrvRec>::iterator I = Recs.begin();
+   std::vector<SrvRec>::iterator const J = std::find_if(Recs.begin(), Recs.end(),
+        [&I](SrvRec const &J) { return I->priority != J.priority; });
 
 
-   return true;
+   // clock seems random enough.
+   I += std::max(static_cast<clock_t>(0), clock()) % std::distance(I, J);
+   SrvRec const selected = std::move(*I);
+   Recs.erase(I);
+
+   if (_config->FindB("Debug::Acquire::SrvRecs", false) == true)
+      std::cerr << "PopFromSrvRecs: selecting " << selected.target << std::endl;
+
+   return selected;
 }
 }