]> git.saurik.com Git - apple/xnu.git/blobdiff - bsd/netinet/cpu_in_cksum_gen.c
xnu-7195.101.1.tar.gz
[apple/xnu.git] / bsd / netinet / cpu_in_cksum_gen.c
index 2cdb63596e8b69ad8a2ccfe52e2669e6c02a1e03..e1cdf126ebb73b8740f41bee319de825660b6fdd 100644 (file)
@@ -108,27 +108,45 @@ uint32_t
 os_cpu_in_cksum(const void *data, uint32_t len, uint32_t initial_sum)
 {
        /*
-        * If data is 4-bytes aligned, length is multiple of 4-bytes,
-        * and the amount to checksum is small, this would be quicker;
-        * this is suitable for IPv4 header.
+        * If data is 4-bytes aligned (conditional), length is multiple
+        * of 4-bytes (required), and the amount to checksum is small,
+        * this would be quicker; this is suitable for IPv4/TCP header.
         */
-       if (IS_P2ALIGNED(data, sizeof(uint32_t)) &&
-           len <= 64 && (len & 3) == 0) {
+       if (
+#if !defined(__arm64__) && !defined(__x86_64__)
+               IS_P2ALIGNED(data, sizeof(uint32_t)) &&
+#endif /* !__arm64__ && !__x86_64__ */
+               len <= 64 && (len & 3) == 0) {
                uint8_t *p = __DECONST(uint8_t *, data);
                uint64_t sum = initial_sum;
 
-               if (PREDICT_TRUE(len == 20)) {  /* simple IPv4 header */
+               switch (len) {
+               case 20:                /* simple IPv4 or TCP header */
                        sum += *(uint32_t *)(void *)p;
                        sum += *(uint32_t *)(void *)(p + 4);
                        sum += *(uint32_t *)(void *)(p + 8);
                        sum += *(uint32_t *)(void *)(p + 12);
                        sum += *(uint32_t *)(void *)(p + 16);
-               } else {
+                       break;
+
+               case 32:                /* TCP header + timestamp option */
+                       sum += *(uint32_t *)(void *)p;
+                       sum += *(uint32_t *)(void *)(p + 4);
+                       sum += *(uint32_t *)(void *)(p + 8);
+                       sum += *(uint32_t *)(void *)(p + 12);
+                       sum += *(uint32_t *)(void *)(p + 16);
+                       sum += *(uint32_t *)(void *)(p + 20);
+                       sum += *(uint32_t *)(void *)(p + 24);
+                       sum += *(uint32_t *)(void *)(p + 28);
+                       break;
+
+               default:
                        while (len) {
                                sum += *(uint32_t *)(void *)p;
                                p += 4;
                                len -= 4;
                        }
+                       break;
                }
 
                /* fold 64-bit to 16-bit (deferred carries) */