git.saurik.com Git - apple/security.git/blobdiff - OSX/libsecurity_ssl/lib/sslContext.c
Security-59306.61.1.tar.gz
[apple/security.git] / OSX / libsecurity_ssl / lib / sslContext.c
index 21dd4628a3e7c9ea47687f005a2fe2e5194e8871..fc1826feb498ef885f4980884d2c4e133d3f3cb4 100644 (file)
@@ -85,6 +85,7 @@ Boolean sslIsSessionActive(const SSLContext *ctx)
                case SSL_HdskStateUninit:
                case SSL_HdskStateGracefulClose:
                case SSL_HdskStateErrorClose:
+        case SSL_HdskStateOutOfBandError:
                        return false;
                default:
                        return true;
@@ -151,6 +152,8 @@ CFTypeRef SSLPreferencesCopyValue(CFStringRef key, CFPropertyListRef managed_pre
 
     if(!value && managed_prefs) {
         value =  CFDictionaryGetValue(managed_prefs, key);
+        if (value)
+            CFRetain(value);
     }
 
     return value;
@@ -419,6 +422,7 @@ void SSLContextDestroy(CFTypeRef arg)
     SSLFreeBuffer(&ctx->peerID);
     SSLFreeBuffer(&ctx->resumableSession);
     SSLFreeBuffer(&ctx->receivedDataBuffer);
+    SSLFreeBuffer(&ctx->contextConfigurationBuffer);
 
     CFReleaseSafe(ctx->acceptableCAs);
 #if !TARGET_OS_IPHONE
@@ -467,6 +471,7 @@ SSLGetSessionState                  (SSLContextRef          context,
                        break;
                case SSL_HdskStateErrorClose:
                case SSL_HdskStateNoNotifyClose:
+        case SSL_HdskStateOutOfBandError:
                        rtnState = kSSLAborted;
                        break;
                case SSL_HdskStateReady:
@@ -488,13 +493,15 @@ SSLSetSessionOption                       (SSLContextRef          context,
                                                         SSLSessionOption       option,
                                                         Boolean                        value)
 {
-    if(context == NULL) {
-       return errSecParam;
+    if (context == NULL) {
+        return errSecParam;
     }
-    if(sslIsSessionActive(context)) {
-       /* can't do this with an active session */
-       return errSecBadReq;
+
+    if (sslIsSessionActive(context)) {
+        /* can't do this with an active session */
+        return errSecBadReq;
     }
+
     switch(option) {
         case kSSLSessionOptionBreakOnServerAuth:
             context->breakOnServerAuth = value;
@@ -509,11 +516,13 @@ SSLSetSessionOption                       (SSLContextRef          context,
             break;
         case kSSLSessionOptionSendOneByteRecord:
             /* Only call the record layer function if the value changed */
-            if(value != context->oneByteRecordEnable)
+            if (value != context->oneByteRecordEnable) {
                 context->recFuncs->setOption(context->recCtx, kSSLRecordOptionSendOneByteRecord, value);
+            }
             context->oneByteRecordEnable = value;
             break;
         case kSSLSessionOptionFalseStart:
+            tls_handshake_set_false_start(context->hdsk, value);
             context->falseStartEnabled = value;
             break;
         case kSSLSessionOptionFallback:
@@ -525,9 +534,15 @@ SSLSetSessionOption                        (SSLContextRef          context,
             break;
         case kSSLSessionOptionAllowServerIdentityChange:
             tls_handshake_set_server_identity_change(context->hdsk, value);
+            context->allowServerIdentityChange = true;
             break;
         case kSSLSessionOptionAllowRenegotiation:
             tls_handshake_set_renegotiation(context->hdsk, value);
+            context->allowRenegotiation = true;
+            break;
+        case kSSLSessionOptionEnableSessionTickets:
+            tls_handshake_set_session_ticket_enabled(context->hdsk, value);
+            context->enableSessionTickets = true;
             break;
         default:
             return errSecParam;
@@ -716,25 +731,135 @@ SSLSetALPNData(SSLContextRef      context,
 
+/*
+ * Return the data pointer of the buffer reported by
+ * tls_handshake_get_peer_alpn_data() and store that buffer's size in *length.
+ * Returns NULL when either argument is NULL or when no buffer is reported.
+ */
 const void *
 SSLGetALPNData(SSLContextRef      context,
-              size_t                           *length)
+               size_t             *length)
 {
-    if (context == NULL || length == NULL)
+    if (context == NULL || length == NULL) {
         return NULL;
+    }
+
+    const tls_buffer *peerAlpn = tls_handshake_get_peer_alpn_data(context->hdsk);
+    if (peerAlpn == NULL) {
+        // *length is left untouched on this path.
+        return NULL;
+    }
+
+    *length = peerAlpn->length;
+    return peerAlpn->data;
+}
+
+/*
+ * Encode `protocols` (an array of CFStringRef) as a sequence of
+ * length-prefixed ASCII names and pass it to tls_handshake_set_alpn_data().
+ *
+ * Entries that are empty, longer than 32 characters, or not convertible to
+ * ASCII are skipped rather than reported.
+ *
+ * Returns errSecParam when an argument is NULL, the array is empty, or the
+ * encoded list is longer than 255 bytes; errSecAllocate when the scratch
+ * CFData cannot be created; otherwise the value returned by
+ * tls_handshake_set_alpn_data().
+ */
+OSStatus
+SSLSetALPNProtocols(SSLContextRef      context,
+                    CFArrayRef         protocols)
+{
+    if (context == NULL || protocols == NULL || CFArrayGetCount(protocols) == 0) {
+        return errSecParam;
+    }
+
+    // RFC 7301 allows protocol names of 1..255 bytes; this function only
+    // accepts names of at most 32 bytes.
+    enum { kMaxProtocolNameLength = 32 };
+
+    // Append each element in the array to a mutable buffer
+    CFMutableDataRef alpnData = CFDataCreateMutable(NULL, 0);
+    if (alpnData == NULL) {
+        return errSecAllocate;
+    }
+    CFArrayForEach(protocols, ^(const void *value) {
+        CFStringRef protocolString = (CFStringRef) value;
+        // Keep the full CFIndex: narrowing to a byte first would let very long
+        // strings wrap around and slip past the length test.
+        CFIndex nameLength = CFStringGetLength(protocolString);
+        if (nameLength >= 1 && nameLength <= kMaxProtocolNameLength) {
+            // CFStringGetCString also writes a NUL terminator, so a maximum
+            // length name needs one byte more than the name itself.
+            char stringBytes[kMaxProtocolNameLength + 1];
+            if (CFStringGetCString(protocolString, stringBytes, sizeof(stringBytes), kCFStringEncodingASCII)) {
+                const UInt8 prefix = (UInt8) nameLength;
+                CFDataAppendBytes(alpnData, &prefix, sizeof(prefix));
+                CFDataAppendBytes(alpnData, (const UInt8 *) stringBytes, nameLength);
+            }
+        }
+    });
+
+    // Length check
+    if (CFDataGetLength(alpnData) > 255) {
+        CFRelease(alpnData);
+        return errSecParam;
+    }
 
-    const tls_buffer *alpn_data;
+    // Pass the buffer down to coreTLS
+    tls_buffer payload;
+    payload.data = (uint8_t *) CFDataGetBytePtr(alpnData);
+    payload.length = CFDataGetLength(alpnData);
+    int success = tls_handshake_set_alpn_data(context->hdsk, payload);
 
-    alpn_data = tls_handshake_get_peer_alpn_data(context->hdsk);
+    // Free up memory and return
+    CFRelease(alpnData);
 
-    if(alpn_data) {
-        *length = alpn_data->length;
-        return alpn_data->data;
+    return success;
+}
+
+/*
+ * Decode the buffer reported by tls_handshake_get_peer_alpn_data() into a new
+ * array of CFStringRef, one string per length-prefixed name.
+ *
+ * On success *protocolArray receives the array, which is created here and not
+ * released by this function, and errSecSuccess is returned.
+ *
+ * errSecParam is returned when an argument is NULL, when no buffer is
+ * reported, when a name runs past the end of the buffer, or when a name
+ * cannot be decoded as ASCII. Except in the NULL-argument case,
+ * *protocolArray is set to NULL on failure.
+ */
+OSStatus
+SSLCopyALPNProtocols(SSLContextRef      context,
+                     CFArrayRef         *protocolArray)
+{
+    if (context == NULL || protocolArray == NULL) {
+        return errSecParam;
+    }
+
+    CFMutableArrayRef array = CFArrayCreateMutableForCFTypes(NULL);
+
+    const tls_buffer *alpnData = tls_handshake_get_peer_alpn_data(context->hdsk);
+    if (alpnData) {
+        size_t offset = 0;
+
+        // Extract each encoded parameter, wrap it in a CFStringRef, and append it to the running list
+        while (offset < alpnData->length) {
+            // The length prefix is an unsigned byte. Reading it into a plain
+            // char (which may be signed) would turn 128..255 into negative
+            // values, defeating the bounds check and moving offset backwards.
+            uint8_t length = alpnData->data[offset];
+            offset++;
+
+            // Make sure we don't exceed the buffer bounds. offset cannot be
+            // larger than alpnData->length here, so the subtraction is safe.
+            if (length > alpnData->length - offset) {
+                CFReleaseNull(array);
+                *protocolArray = NULL;
+                return errSecParam;
+            }
+
+            CFStringRef protocol = CFStringCreateWithBytes(NULL, alpnData->data + offset, length, kCFStringEncodingASCII, false);
+            if (protocol == NULL) {
+                // Creation failed (e.g. the bytes are not valid ASCII); a NULL
+                // value must not be appended to the array.
+                CFReleaseNull(array);
+                *protocolArray = NULL;
+                return errSecParam;
+            }
+            offset += length;
+            CFArrayAppendValue(array, protocol);
+            CFReleaseNull(protocol);
+        }
+
+        *protocolArray = array;
+        return errSecSuccess;
     } else {
-        return NULL;
+        CFReleaseNull(array);
+        *protocolArray = NULL;
+        return errSecParam;
+    }
 }
 
 // ALPN end
 
+// OCSP response begin
+
+/*
+ * Pass the bytes of `response` to tls_handshake_set_ocsp_response() for this
+ * context. Returns errSecParam when either argument is NULL, otherwise the
+ * value returned by the handshake call.
+ */
+OSStatus
+SSLSetOCSPResponse(SSLContextRef      context,
+                   CFDataRef          response)
+{
+    if (context == NULL || response == NULL) {
+        return errSecParam;
+    }
+
+    // View onto the CFData storage; this function makes no copy of the bytes.
+    tls_buffer ocsp = {
+        .data = (uint8_t *) CFDataGetBytePtr(response),
+        .length = CFDataGetLength(response),
+    };
+
+    return tls_handshake_set_ocsp_response(context->hdsk, &ocsp);
+}
+
+// OCSP response end
+
 OSStatus
 SSLSetConnection                       (SSLContextRef          ctx,
                                                         SSLConnectionRef       connection)
@@ -1070,12 +1195,12 @@ SSLSetProtocolVersionMax  (SSLContextRef      ctx,
         if (version > MINIMUM_DATAGRAM_VERSION ||
             version < MAXIMUM_DATAGRAM_VERSION)
             return errSSLIllegalParam;
-        if (version > ctx->minProtocolVersion)
+        if (version > (SSLProtocolVersion)ctx->minProtocolVersion)
             ctx->minProtocolVersion = version;
     } else {
         if (version < MINIMUM_STREAM_VERSION || version > MAXIMUM_STREAM_VERSION)
             return errSSLIllegalParam;
-        if (version < ctx->minProtocolVersion)
+        if (version < (SSLProtocolVersion)ctx->minProtocolVersion)
             ctx->minProtocolVersion = version;
     }
     ctx->maxProtocolVersion = version;
@@ -1096,6 +1221,49 @@ SSLGetProtocolVersionMax  (SSLContextRef      ctx,
     return errSecSuccess;
 }
 
+/*
+ * Map an SSLProtocol constant to the corresponding tls_protocol_version.
+ *
+ * kTLSProtocolMaxSupported maps to the same value as kTLSProtocol13.
+ * kSSLProtocolUnknown, and any value without a case below, yields
+ * tls_protocol_version_Undertermined. The deprecated selectors (kSSLProtocol2,
+ * kSSLProtocol3Only, kTLSProtocol1Only, kSSLProtocolAll) yield that same value
+ * after an error-log entry has been written.
+ */
+tls_protocol_version
+_SSLProtocolVersionToWireFormatValue   (SSLProtocol protocol)
+{
+    switch (protocol) {
+        case kSSLProtocol3:
+            return tls_protocol_version_SSL_3;
+        case kTLSProtocol1:
+            return tls_protocol_version_TLS_1_0;
+        case kTLSProtocol11:
+            return tls_protocol_version_TLS_1_1;
+        case kTLSProtocol12:
+            return tls_protocol_version_TLS_1_2;
+        case kTLSProtocol13:
+        case kTLSProtocolMaxSupported:
+            return tls_protocol_version_TLS_1_3;
+        case kDTLSProtocol1:
+            return tls_protocol_version_DTLS_1_0;
+        case kDTLSProtocol12:
+            return tls_protocol_version_DTLS_1_2;
+        case kSSLProtocolUnknown:
+            break;
+        case kSSLProtocol2:
+        case kSSLProtocol3Only:
+        case kTLSProtocol1Only:
+        case kSSLProtocolAll:
+            sslErrorLog("SSLProtocol %d is deprecated. Setting to the default value (%d)", protocol, tls_protocol_version_Undertermined);
+            break;
+    }
+
+    // Shared result for unknown, deprecated and unlisted values.
+    return tls_protocol_version_Undertermined;
+}
+
 #define max(x,y) ((x)<(y)?(y):(x))
 
 OSStatus
@@ -1124,12 +1292,12 @@ SSLSetProtocolVersionEnabled(SSLContextRef     ctx,
                        if (version < MINIMUM_STREAM_VERSION || version > MAXIMUM_STREAM_VERSION) {
                                return errSecParam;
                        }
-            if (version > ctx->maxProtocolVersion) {
+            if (version > (SSLProtocolVersion)ctx->maxProtocolVersion) {
                 ctx->maxProtocolVersion = version;
                 if (ctx->minProtocolVersion == SSL_Version_Undetermined)
                     ctx->minProtocolVersion = version;
             }
-            if (version < ctx->minProtocolVersion) {
+            if (version < (SSLProtocolVersion)ctx->minProtocolVersion) {
                 ctx->minProtocolVersion = version;
             }
         } else {
@@ -1159,7 +1327,7 @@ SSLSetProtocolVersionEnabled(SSLContextRef     ctx,
                                        nextVersion = SSL_Version_Undetermined;
                                        break;
                        }
-                       ctx->minProtocolVersion = max(ctx->minProtocolVersion, nextVersion);
+                       ctx->minProtocolVersion = (tls_protocol_version)max((SSLProtocolVersion)ctx->minProtocolVersion, nextVersion);
                        if (ctx->minProtocolVersion > ctx->maxProtocolVersion) {
                                ctx->minProtocolVersion = SSL_Version_Undetermined;
                                ctx->maxProtocolVersion = SSL_Version_Undetermined;
@@ -1193,8 +1361,8 @@ SSLGetProtocolVersionEnabled(SSLContextRef                ctx,
         case kTLSProtocol12:
         {
             SSLProtocolVersion version = SSLProtocolToProtocolVersion(protocol);
-                       *enable = (ctx->minProtocolVersion <= version
-                       && ctx->maxProtocolVersion >= version);
+                       *enable = ((SSLProtocolVersion)ctx->minProtocolVersion <= version
+                       && (SSLProtocolVersion)ctx->maxProtocolVersion >= version);
                        break;
         }
                case kSSLProtocolAll:
@@ -1467,7 +1635,7 @@ SSLCopyTrustedRoots                       (SSLContextRef          ctx,
                CFRetain(ctx->trustedCerts);
                return errSecSuccess;
        }
-#if (TARGET_OS_MAC && !(TARGET_OS_EMBEDDED || TARGET_OS_IPHONE))
+#if TARGET_OS_OSX
        /* use default system roots */
     return sslDefaultSystemRoots(ctx, trustedRoots);
 #else
@@ -2350,8 +2518,11 @@ OSStatus SSLGetECDSACurves(
        if(*numCurves < ctx->ecdhNumCurves) {
                return errSecParam;
        }
-       memmove(namedCurves, ctx->ecdhCurves,
-               (ctx->ecdhNumCurves * sizeof(SSL_ECDSA_NamedCurve)));
+       static_assert(sizeof(*namedCurves) >= sizeof(*(ctx->ecdhCurves)),
+               "SSL_ECDSA_NamedCurve must be large enough for SSLContext ecdhCurves.");
+       for (unsigned i = 0; i < ctx->ecdhNumCurves; i++) {
+               namedCurves[i] = ctx->ecdhCurves[i];
+       }
        *numCurves = ctx->ecdhNumCurves;
        return errSecSuccess;
 }
@@ -2372,20 +2543,26 @@ OSStatus SSLSetECDSACurves(
                return errSecBadReq;
        }
 
-       size_t size = numCurves * sizeof(uint16_t);
-       ctx->ecdhCurves = (uint16_t *)sslMalloc(size);
+       if (SIZE_MAX / sizeof(*(ctx->ecdhCurves)) < (size_t)numCurves) {
+               return errSecParam;
+       }
+       ctx->ecdhCurves = sslMalloc((size_t)numCurves * sizeof(*(ctx->ecdhCurves)));
        if(ctx->ecdhCurves == NULL) {
                ctx->ecdhNumCurves = 0;
                return errSecAllocate;
        }
 
-    for (unsigned i=0; i<numCurves; i++) {
-        ctx->ecdhCurves[i] = namedCurves[i];
-    }
+       for (unsigned i=0; i<numCurves; i++) {
+               if (namedCurves[i] > UINT16_MAX - 1) {
+                       ctx->ecdhCurves[i] = SSL_Curve_None;
+                       continue;
+               }
+               ctx->ecdhCurves[i] = namedCurves[i];
+       }
 
        ctx->ecdhNumCurves = numCurves;
 
-    tls_handshake_set_curves(ctx->hdsk, ctx->ecdhCurves, ctx->ecdhNumCurves);
+       tls_handshake_set_curves(ctx->hdsk, ctx->ecdhCurves, ctx->ecdhNumCurves);
        return errSecSuccess;
 }
 
@@ -2595,3 +2772,79 @@ SSLSetSessionConfig(SSLContextRef context,
         return errSecParam;
     }
 }
+
+/*
+ * Fill `buffer` with a byte string built from the session options stored on
+ * `ctx` (the ten fields copied below, in that order).
+ *
+ * On success buffer->data is a freshly malloc'd block and buffer->length is
+ * its size. The first successful result is also copied into
+ * ctx->contextConfigurationBuffer; later calls return a copy of that cache
+ * instead of re-reading the fields.
+ *
+ * buffer->data must be NULL or a malloc'd pointer on entry: the uncached path
+ * frees it. The cached path overwrites it without freeing.
+ *
+ * Returns errSecParam when an argument is NULL, errSecAllocate when the output
+ * block cannot be allocated, and errSecInternal when the number of bytes
+ * written does not match the expected size.
+ */
+OSStatus
+SSLGetSessionConfigurationIdentifier(SSLContext *ctx, SSLBuffer *buffer)
+{
+    if (ctx == NULL || buffer == NULL) {
+        return errSecParam;
+    }
+
+    // Don't recompute the buffer if we've done it before and cached the result.
+    // Just copy out the result.
+    if (ctx->contextConfigurationBuffer.data != NULL) {
+        buffer->length = ctx->contextConfigurationBuffer.length;
+        buffer->data = (uint8_t *) malloc(buffer->length);
+        if (buffer->data == NULL) {
+            return errSecAllocate;
+        }
+        memcpy(buffer->data, ctx->contextConfigurationBuffer.data, buffer->length);
+        return errSecSuccess;
+    }
+
+    // Allocate the buffer, freeing up any data that was previously stored
+    // 10 here is the number of attributes we're adding below. Change it as needed.
+    free(buffer->data);
+    buffer->length = 10 * sizeof(Boolean);
+    buffer->data = malloc(buffer->length);
+    if (buffer->data == NULL) {
+        return errSecAllocate;
+    }
+
+    // Copy in the session configuration options
+    size_t offset = 0;
+    memcpy(buffer->data + offset, &ctx->breakOnServerAuth, sizeof(ctx->breakOnServerAuth));
+    offset += sizeof(ctx->breakOnServerAuth);
+
+    memcpy(buffer->data + offset, &ctx->breakOnCertRequest, sizeof(ctx->breakOnCertRequest));
+    offset += sizeof(ctx->breakOnCertRequest);
+
+    memcpy(buffer->data + offset, &ctx->breakOnClientAuth, sizeof(ctx->breakOnClientAuth));
+    offset += sizeof(ctx->breakOnClientAuth);
+
+    memcpy(buffer->data + offset, &ctx->signalServerAuth, sizeof(ctx->signalServerAuth));
+    offset += sizeof(ctx->signalServerAuth);
+
+    memcpy(buffer->data + offset, &ctx->signalCertRequest, sizeof(ctx->signalCertRequest));
+    offset += sizeof(ctx->signalCertRequest);
+
+    memcpy(buffer->data + offset, &ctx->signalClientAuth, sizeof(ctx->signalClientAuth));
+    offset += sizeof(ctx->signalClientAuth);
+
+    memcpy(buffer->data + offset, &ctx->breakOnClientHello, sizeof(ctx->breakOnClientHello));
+    offset += sizeof(ctx->breakOnClientHello);
+
+    memcpy(buffer->data + offset, &ctx->allowServerIdentityChange, sizeof(ctx->allowServerIdentityChange));
+    offset += sizeof(ctx->allowServerIdentityChange);
+
+    memcpy(buffer->data + offset, &ctx->allowRenegotiation, sizeof(ctx->allowRenegotiation));
+    offset += sizeof(ctx->allowRenegotiation);
+
+    memcpy(buffer->data + offset, &ctx->enableSessionTickets, sizeof(ctx->enableSessionTickets));
+    offset += sizeof(ctx->enableSessionTickets);
+
+    // Sanity check on the length
+    if (offset != buffer->length) {
+        free(buffer->data);
+        // Do not hand a dangling pointer back to the caller.
+        buffer->data = NULL;
+        buffer->length = 0;
+        return errSecInternal;
+    }
+
+    // Save the configuration buffer for later use. Caching is best effort: if
+    // the copy cannot be allocated the cache stays empty and the next call
+    // rebuilds the identifier, but this call still succeeds.
+    uint8_t *cached = malloc(buffer->length);
+    if (cached != NULL) {
+        memcpy(cached, buffer->data, buffer->length);
+        ctx->contextConfigurationBuffer.data = cached;
+        ctx->contextConfigurationBuffer.length = buffer->length;
+    }
+
+    return errSecSuccess;
+}