]> git.saurik.com Git - apple/security.git/blobdiff - SecureTransport/sslHandshakeHello.cpp
Security-54.1.3.tar.gz
[apple/security.git] / SecureTransport / sslHandshakeHello.cpp
diff --git a/SecureTransport/sslHandshakeHello.cpp b/SecureTransport/sslHandshakeHello.cpp
new file mode 100644 (file)
index 0000000..74a7681
--- /dev/null
@@ -0,0 +1,362 @@
+/*
+ * Copyright (c) 2000-2001 Apple Computer, Inc. All Rights Reserved.
+ * 
+ * The contents of this file constitute Original Code as defined in and are
+ * subject to the Apple Public Source License Version 1.2 (the 'License').
+ * You may not use this file except in compliance with the License. Please obtain
+ * a copy of the License at http://www.apple.com/publicsource and read it before
+ * using this file.
+ * 
+ * This Original Code and all software distributed under the License are
+ * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS
+ * OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES, INCLUDING WITHOUT
+ * LIMITATION, ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT. Please see the License for the
+ * specific language governing rights and limitations under the License.
+ */
+
+
+/*
+       File:           sslHandshakeHello.c
+
+       Contains:       Support for client hello and server hello messages. 
+
+       Written by:     Doug Mitchell
+
+       Copyright: (c) 1999 by Apple Computer, Inc., all rights reserved.
+
+*/
+
+#include "sslContext.h"
+#include "sslHandshake.h"
+#include "sslMemory.h"
+#include "sslSession.h"
+#include "sslUtils.h"
+#include "sslDebug.h"
+#include "appleCdsa.h"
+#include "sslDigests.h"
+#include "cipherSpecs.h"
+
+#include <string.h>
+
+static OSStatus SSLEncodeRandom(unsigned char *p, SSLContext *ctx);
+
+/* IE treats null session id as valid; two consecutive sessions with NULL ID
+ * are considered a match. Workaround: when resumable sessions are disabled, 
+ * send a random session ID. */
+#define SSL_IE_NULL_RESUME_BUG         1
+#if            SSL_IE_NULL_RESUME_BUG
+#define SSL_NULL_ID_LEN                                32      /* length of bogus session ID */
+#endif
+
+OSStatus
+SSLEncodeServerHello(SSLRecord &serverHello, SSLContext *ctx)
+{   OSStatus        err;
+    UInt8           *charPtr;
+    int             sessionIDLen;
+    
+    sessionIDLen = 0;
+    if (ctx->sessionID.data != 0)
+        sessionIDLen = (UInt8)ctx->sessionID.length;
+       #if     SSL_IE_NULL_RESUME_BUG
+       if(sessionIDLen == 0) {
+               sessionIDLen = SSL_NULL_ID_LEN;
+       }       
+       #endif  /* SSL_IE_NULL_RESUME_BUG */
+               
+       sslLogNegotiateDebug("===SSL3 server: sending version %d_%d",
+               ctx->negProtocolVersion >> 8, ctx->negProtocolVersion & 0xff);
+       sslLogNegotiateDebug("...sessionIDLen = %d", sessionIDLen);
+    serverHello.protocolVersion = ctx->negProtocolVersion;
+    serverHello.contentType = SSL_RecordTypeHandshake;
+    if ((err = SSLAllocBuffer(serverHello.contents, 42 + sessionIDLen, ctx)) != 0)
+        return err;
+    
+    charPtr = serverHello.contents.data;
+    *charPtr++ = SSL_HdskServerHello;
+    charPtr = SSLEncodeInt(charPtr, 38 + sessionIDLen, 3);
+    charPtr = SSLEncodeInt(charPtr, serverHello.protocolVersion, 2);
+    if ((err = SSLEncodeRandom(charPtr, ctx)) != 0)
+        return err;
+    memcpy(ctx->serverRandom, charPtr, SSL_CLIENT_SRVR_RAND_SIZE);
+    charPtr += SSL_CLIENT_SRVR_RAND_SIZE;
+       *(charPtr++) = (UInt8)sessionIDLen;
+       #if     SSL_IE_NULL_RESUME_BUG
+       if(ctx->sessionID.data != NULL) {
+               /* normal path for enabled resumable session */
+               memcpy(charPtr, ctx->sessionID.data, sessionIDLen);
+       }
+       else {
+               /* IE workaround */
+               SSLBuffer rb;
+               rb.data = charPtr;
+               rb.length = SSL_NULL_ID_LEN;
+               sslRand(ctx, &rb);
+       }
+       #else   
+    if (sessionIDLen > 0)
+        memcpy(charPtr, ctx->sessionID.data, sessionIDLen);
+       #endif  /* SSL_IE_NULL_RESUME_BUG */
+       charPtr += sessionIDLen;
+    charPtr = SSLEncodeInt(charPtr, ctx->selectedCipher, 2);
+    *(charPtr++) = 0;      /* Null compression */
+
+    sslLogNegotiateDebug("ssl3: server specifying cipherSuite 0x%lx", 
+               (UInt32)ctx->selectedCipher);
+       
+    assert(charPtr == serverHello.contents.data + serverHello.contents.length);
+    
+    return noErr;
+}
+
+OSStatus
+SSLProcessServerHello(SSLBuffer message, SSLContext *ctx)
+{   OSStatus            err;
+    SSLProtocolVersion  protocolVersion;
+    unsigned int        sessionIDLen;
+    UInt8               *p;
+    
+    assert(ctx->protocolSide == SSL_ClientSide);
+    
+    if (message.length < 38 || message.length > 70) {
+       sslErrorLog("SSLProcessServerHello: msg len error\n");
+        return errSSLProtocol;
+    }
+    p = message.data;
+    
+    protocolVersion = (SSLProtocolVersion)SSLDecodeInt(p, 2);
+    p += 2;
+    if (protocolVersion > ctx->maxProtocolVersion) {
+        return errSSLNegotiation;
+       }
+    ctx->negProtocolVersion = protocolVersion;
+       switch(protocolVersion) {
+               case SSL_Version_3_0:
+                       ctx->sslTslCalls = &Ssl3Callouts;
+                       break;
+               case TLS_Version_1_0:
+                       ctx->sslTslCalls = &Tls1Callouts;
+                       break;
+               default:
+                       return errSSLNegotiation;
+       }
+    sslLogNegotiateDebug("===SSL3 client: negVersion is %d_%d",
+               (protocolVersion >> 8) & 0xff, protocolVersion & 0xff);
+    
+    memcpy(ctx->serverRandom, p, 32);
+    p += 32;
+    
+    sessionIDLen = *p++;
+    if (message.length != 38 + sessionIDLen) {
+       sslErrorLog("SSLProcessServerHello: msg len error 2\n");
+        return errSSLProtocol;
+    }
+    if (sessionIDLen > 0 && ctx->peerID.data != 0)
+    {   /* Don't die on error; just treat it as an uncached session */
+        err = SSLAllocBuffer(ctx->sessionID, sessionIDLen, ctx);
+        if (err == 0)
+            memcpy(ctx->sessionID.data, p, sessionIDLen);
+    }
+    p += sessionIDLen;
+    
+    ctx->selectedCipher = (UInt16)SSLDecodeInt(p,2);
+    sslLogNegotiateDebug("===ssl3: server requests cipherKind %d", 
+       (unsigned)ctx->selectedCipher);
+    p += 2;
+    if ((err = FindCipherSpec(ctx)) != 0) {
+        return err;
+    }
+    
+    if (*p++ != 0)      /* Compression */
+        return unimpErr;
+    
+    assert(p == message.data + message.length);
+    return noErr;
+}
+
+OSStatus
+SSLEncodeClientHello(SSLRecord &clientHello, SSLContext *ctx)
+{   
+       unsigned                length, i;
+    OSStatus        err;
+    unsigned char   *p;
+    SSLBuffer       sessionIdentifier;
+    UInt16          sessionIDLen;
+    
+    assert(ctx->protocolSide == SSL_ClientSide);
+    
+    sessionIDLen = 0;
+    if (ctx->resumableSession.data != 0)
+    {   if ((err = SSLRetrieveSessionID(ctx->resumableSession,
+                               &sessionIdentifier, ctx)) != 0)
+        {   return err;
+        }
+        sessionIDLen = sessionIdentifier.length;
+    }
+    
+    length = 39 + 2*(ctx->numValidCipherSpecs) + sessionIDLen;
+    
+    clientHello.protocolVersion = ctx->maxProtocolVersion;
+    clientHello.contentType = SSL_RecordTypeHandshake;
+    if ((err = SSLAllocBuffer(clientHello.contents, length + 4, ctx)) != 0)
+        return err;
+    
+    p = clientHello.contents.data;
+    *p++ = SSL_HdskClientHello;
+    p = SSLEncodeInt(p, length, 3);
+    p = SSLEncodeInt(p, ctx->maxProtocolVersion, 2);
+       sslLogNegotiateDebug("===SSL3 client: proclaiming max protocol "
+               "%d_%d capable ONLY",
+               ctx->maxProtocolVersion >> 8, ctx->maxProtocolVersion & 0xff);
+   if ((err = SSLEncodeRandom(p, ctx)) != 0)
+    {   SSLFreeBuffer(clientHello.contents, ctx);
+        return err;
+    }
+    memcpy(ctx->clientRandom, p, SSL_CLIENT_SRVR_RAND_SIZE);
+    p += 32;
+    *p++ = sessionIDLen;                               /* 1 byte vector length */
+    if (sessionIDLen > 0)
+    {   memcpy(p, sessionIdentifier.data, sessionIDLen);
+        if ((err = SSLFreeBuffer(sessionIdentifier, ctx)) != 0)
+            return err;
+    }
+    p += sessionIDLen;
+    p = SSLEncodeInt(p, 2*(ctx->numValidCipherSpecs), 2);  
+                                                                                       /* 2 byte long vector length */
+    for (i = 0; i<ctx->numValidCipherSpecs; ++i)
+        p = SSLEncodeInt(p, ctx->validCipherSpecs[i].cipherSpec, 2);
+    *p++ = 1;                               /* 1 byte long vector */
+    *p++ = 0;                               /* null compression */
+    
+    assert(p == clientHello.contents.data + clientHello.contents.length);
+    
+    if ((err = SSLInitMessageHashes(ctx)) != 0)
+        return err;
+    
+    return noErr;
+}
+
+OSStatus
+SSLProcessClientHello(SSLBuffer message, SSLContext *ctx)
+{   OSStatus            err;
+    SSLProtocolVersion  clientVersion;
+    UInt16              cipherListLen, cipherCount, desiredSpec, cipherSpec;
+    UInt8               sessionIDLen, compressionCount;
+    UInt8               *charPtr;
+    unsigned            i;
+    
+    if (message.length < 41) {
+       sslErrorLog("SSLProcessClientHello: msg len error 1\n");
+        return errSSLProtocol;
+    }
+    charPtr = message.data;
+    clientVersion = (SSLProtocolVersion)SSLDecodeInt(charPtr, 2);
+    charPtr += 2;
+       if(clientVersion > ctx->maxProtocolVersion) {
+               clientVersion = ctx->maxProtocolVersion;
+       }
+       switch(clientVersion) {
+               case SSL_Version_3_0:
+                       ctx->sslTslCalls = &Ssl3Callouts;
+                       break;
+               case TLS_Version_1_0:
+                       ctx->sslTslCalls = &Tls1Callouts;
+                       break;
+               default:
+                       return errSSLNegotiation;
+       }
+       ctx->negProtocolVersion = clientVersion;
+    sslLogNegotiateDebug("===SSL3 server: negVersion is %d_%d",
+               clientVersion >> 8, clientVersion & 0xff);
+    
+    memcpy(ctx->clientRandom, charPtr, SSL_CLIENT_SRVR_RAND_SIZE);
+    charPtr += 32;
+    sessionIDLen = *(charPtr++);
+    if (message.length < (unsigned)(41 + sessionIDLen)) {
+       sslErrorLog("SSLProcessClientHello: msg len error 2\n");
+        return errSSLProtocol;
+    }
+    if (sessionIDLen > 0 && ctx->peerID.data != 0)
+    {   /* Don't die on error; just treat it as an uncacheable session */
+        err = SSLAllocBuffer(ctx->sessionID, sessionIDLen, ctx);
+        if (err == 0)
+            memcpy(ctx->sessionID.data, charPtr, sessionIDLen);
+    }
+    charPtr += sessionIDLen;
+    
+    cipherListLen = (UInt16)SSLDecodeInt(charPtr, 2);  
+                                                               /* Count of cipherSpecs, must be even & >= 2 */
+    charPtr += 2;
+    if ((cipherListLen & 1) || 
+           (cipherListLen < 2) || 
+               (message.length < (unsigned)(39 + sessionIDLen + cipherListLen))) {
+       sslErrorLog("SSLProcessClientHello: msg len error 3\n");
+        return errSSLProtocol;
+    }
+    cipherCount = cipherListLen/2;
+    cipherSpec = 0xFFFF;        /* No match marker */
+    while (cipherSpec == 0xFFFF && cipherCount--)
+    {   desiredSpec = (UInt16)SSLDecodeInt(charPtr, 2);
+        charPtr += 2;
+        for (i = 0; i <ctx->numValidCipherSpecs; i++)
+        {   if (ctx->validCipherSpecs[i].cipherSpec == desiredSpec)
+            {   cipherSpec = desiredSpec;
+                break;
+            }
+        }
+    }
+    
+    if (cipherSpec == 0xFFFF)
+        return errSSLNegotiation;
+    charPtr += 2 * cipherCount;    /* Advance past unchecked cipherCounts */
+    ctx->selectedCipher = cipherSpec;
+    if ((err = FindCipherSpec(ctx)) != 0) {
+        return err;
+    }
+    sslLogNegotiateDebug("ssl3 server: selecting cipherKind 0x%x", (unsigned)ctx->selectedCipher);
+    
+    compressionCount = *(charPtr++);
+    if ((compressionCount < 1) || 
+           (message.length < 
+                   (unsigned)(38 + sessionIDLen + cipherListLen + compressionCount))) {
+       sslErrorLog("SSLProcessClientHello: msg len error 4\n");
+        return errSSLProtocol;
+    }
+    /* Ignore list; we're doing null */
+    
+    if ((err = SSLInitMessageHashes(ctx)) != 0)
+        return err;
+    
+    return noErr;
+}
+
+static OSStatus
+SSLEncodeRandom(unsigned char *p, SSLContext *ctx)
+{   SSLBuffer   randomData;
+    OSStatus    err;
+    UInt32      time;
+    
+    if ((err = sslTime(&time)) != 0)
+        return err;
+    SSLEncodeInt(p, time, 4);
+    randomData.data = p+4;
+    randomData.length = 28;
+       if((err = sslRand(ctx, &randomData)) != 0)
+        return err;
+    return noErr;
+}
+
+OSStatus
+SSLInitMessageHashes(SSLContext *ctx)
+{   OSStatus          err;
+
+    if ((err = CloseHash(SSLHashSHA1, ctx->shaState, ctx)) != 0)
+        return err;
+    if ((err = CloseHash(SSLHashMD5,  ctx->md5State, ctx)) != 0)
+        return err;
+    if ((err = ReadyHash(SSLHashSHA1, ctx->shaState, ctx)) != 0)
+        return err;
+    if ((err = ReadyHash(SSLHashMD5,  ctx->md5State, ctx)) != 0)
+        return err;
+    return noErr;
+}