]> git.saurik.com Git - apple/security.git/blobdiff - SecureTransport/ssl2Record.cpp
Security-54.1.3.tar.gz
[apple/security.git] / SecureTransport / ssl2Record.cpp
diff --git a/SecureTransport/ssl2Record.cpp b/SecureTransport/ssl2Record.cpp
new file mode 100644 (file)
index 0000000..6cfaafa
--- /dev/null
@@ -0,0 +1,420 @@
+/*
+ * Copyright (c) 2000-2001 Apple Computer, Inc. All Rights Reserved.
+ * 
+ * The contents of this file constitute Original Code as defined in and are
+ * subject to the Apple Public Source License Version 1.2 (the 'License').
+ * You may not use this file except in compliance with the License. Please obtain
+ * a copy of the License at http://www.apple.com/publicsource and read it before
+ * using this file.
+ * 
+ * This Original Code and all software distributed under the License are
+ * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS
+ * OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES, INCLUDING WITHOUT
+ * LIMITATION, ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ * PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT. Please see the License for the
+ * specific language governing rights and limitations under the License.
+ */
+
+
+/*
+       File:           ssl2Record.cpp
+
+       Contains:       Record encrypting/decrypting/MACing for SSL 2
+
+       Written by:     Doug Mitchell
+
+       Copyright: (c) 1999 by Apple Computer, Inc., all rights reserved.
+
+*/
+
+#include "ssl2.h"
+#include "sslRecord.h"
+#include "sslMemory.h"
+#include "sslContext.h"
+#include "sslAlertMessage.h"
+#include "sslDebug.h"
+#include "sslUtils.h"
+#include "sslDigests.h"
+
+#include <string.h>
+
+static OSStatus SSL2DecryptRecord(
+       SSLBuffer &payload, 
+       SSLContext *ctx);
+static OSStatus SSL2VerifyMAC(
+       SSLBuffer &content, 
+       UInt8 *compareMAC, 
+       SSLContext *ctx);
+static OSStatus SSL2CalculateMAC(
+       SSLBuffer &secret, 
+       SSLBuffer &content, 
+       UInt32 seqNo, 
+       const HashReference &hash, 
+       SSLBuffer &mac, 
+       SSLContext *ctx);
+
+
+OSStatus
+SSL2ReadRecord(SSLRecord &rec, SSLContext *ctx)
+{   OSStatus        err;
+    UInt32          len, contentLen;
+    int             padding, headerSize;
+    UInt8           *charPtr;
+    SSLBuffer       readData, cipherFragment;
+    
+    switch (ctx->negProtocolVersion)
+    {   case SSL_Version_Undetermined:
+        case SSL_Version_3_0_With_2_0_Hello:
+        case SSL_Version_2_0:
+            break;
+        case SSL_Version_3_0:           /* We've negotiated a 3.0 session; 
+                                                                                * we can send an alert */
+               case TLS_Version_1_0:
+            SSLFatalSessionAlert(SSL_AlertUnexpectedMsg, ctx);
+            return errSSLProtocol;
+        case SSL_Version_3_0_Only:      /* We haven't yet negotiated, but 
+                                                                                * we don't want to support 2.0; just 
+                                                                                * die without an alert */
+            return errSSLProtocol;
+        default:
+            sslErrorLog("bad protocolVersion in ctx->protocolVersion");
+                       return errSSLInternal;
+    }
+    
+    if (!ctx->partialReadBuffer.data || ctx->partialReadBuffer.length < 3)
+    {   if (ctx->partialReadBuffer.data)
+            if ((err = SSLFreeBuffer(ctx->partialReadBuffer, ctx)) != 0)
+            {   SSLFatalSessionAlert(SSL_AlertCloseNotify, ctx);
+                return err;
+            }
+        if ((err = SSLAllocBuffer(ctx->partialReadBuffer, DEFAULT_BUFFER_SIZE, ctx)) != 0)
+        {   SSLFatalSessionAlert(SSL_AlertCloseNotify, ctx);
+            return err;
+        }
+    }
+    
+    if (ctx->amountRead < 3)
+    {   readData.length = 3 - ctx->amountRead;
+        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
+        len = readData.length;
+        err = sslIoRead(readData, &len, ctx);
+        if(err != 0)
+        {   if (err == errSSLWouldBlock)
+                ctx->amountRead += len;
+            if (err == ioErr && ctx->amountRead == 0)    /* If the session closes on a record boundary, it's graceful */
+                err = errSSLClosedGraceful;
+            return err;
+        }
+        ctx->amountRead += len;
+    }
+    
+    rec.contentType = SSL_RecordTypeV2_0;
+    rec.protocolVersion = SSL_Version_2_0;
+    charPtr = ctx->partialReadBuffer.data;
+    
+    if (((*charPtr) & 0x80) != 0)       /* High bit on -> specifies 2-byte header */
+    {   headerSize = 2;
+        contentLen = ((charPtr[0] & 0x7F) << 8) | charPtr[1];
+        padding = 0;
+    }
+    else if (((*charPtr) & 0x40) != 0) /* Bit 6 on -> specifies security escape */
+    {   return errSSLProtocol;          /* No security escapes are defined */
+    }
+    else                                /* 3-byte header */
+    {   headerSize = 3;
+        contentLen = ((charPtr[0] & 0x3F) << 8) | charPtr[1];
+        padding = charPtr[2];
+    }
+    
+    /* 
+     * FIXME - what's the max record size?
+     * and why doesn't SSLReadRecord parse the 2 or 3 byte header?
+        * Note: I see contentLen of 0 coming back from www.cduniverse.com when
+        * it's only been given SSL_RSA_EXPORT_WITH_DES40_CBC_SHA.
+     */
+    if((contentLen == 0) || (contentLen > 0xffff)) {
+       return errSSLProtocol;
+    }
+    
+    charPtr += headerSize;
+    
+    if (ctx->partialReadBuffer.length < headerSize + contentLen)
+    {   if ((err = SSLReallocBuffer(ctx->partialReadBuffer, 5 + contentLen, ctx)) != 0)
+            return err;
+    }
+    
+    if (ctx->amountRead < headerSize + contentLen)
+    {   readData.length = headerSize + contentLen - ctx->amountRead;
+        readData.data = ctx->partialReadBuffer.data + ctx->amountRead;
+        len = readData.length;
+        err = sslIoRead(readData, &len, ctx);
+        if(err != 0)
+        {   if (err == errSSLWouldBlock)
+                ctx->amountRead += len;
+            return err;
+        }
+        ctx->amountRead += len;
+    }
+    
+    cipherFragment.data = ctx->partialReadBuffer.data + headerSize;
+    cipherFragment.length = contentLen;
+    if ((err = SSL2DecryptRecord(cipherFragment, ctx)) != 0)
+        return err;
+    
+    cipherFragment.length -= padding;       /* Remove padding; MAC was removed 
+                                                                                        * by SSL2DecryptRecord */
+    
+    IncrementUInt64(&ctx->readCipher.sequenceNum);
+    
+       /* Allocate a buffer to return the plaintext in and return it */
+    if ((err = SSLAllocBuffer(rec.contents, cipherFragment.length, ctx)) != 0)
+        return err;
+    memcpy(rec.contents.data, cipherFragment.data, cipherFragment.length);
+    
+    ctx->amountRead = 0;        /* We've used all the data in the cache */
+    
+    return noErr;
+}
+
+OSStatus
+SSL2WriteRecord(SSLRecord &rec, SSLContext *ctx)
+{   OSStatus        err;
+    int             padding = 0, i, headerSize;
+    WaitingRecord   *out, *queue;
+    SSLBuffer       buf, content, payload, secret, mac;
+    UInt8           *charPtr;
+    UInt16          payloadSize, blockSize;
+    
+    assert(rec.contents.length < 16384);
+    
+    out = 0;
+    /* Allocate a WaitingRecord to store our ready-to-send record in */
+    if ((err = SSLAllocBuffer(buf, sizeof(WaitingRecord), ctx)) != 0)
+        return err;
+    out = (WaitingRecord*)buf.data;
+    out->next = 0;
+    out->sent = 0;
+        
+    payloadSize = (UInt16) 
+               (rec.contents.length + ctx->writeCipher.macRef->hash->digestSize);
+    blockSize = ctx->writeCipher.symCipher->blockSize;
+    if (blockSize > 0)
+    {   
+               padding = blockSize - (payloadSize % blockSize);
+        if (padding == blockSize)
+            padding = 0;
+        payloadSize += padding;
+        headerSize = 3;
+    }
+    else
+    {   padding = 0;
+        headerSize = 2;
+    }
+    out->data.data = 0;
+    if ((err = SSLAllocBuffer(out->data, headerSize + payloadSize, ctx)) != 0)
+        goto fail;
+    charPtr = out->data.data;
+    
+    if (headerSize == 2)
+        charPtr = SSLEncodeInt(charPtr, payloadSize | 0x8000, 2);
+    else
+    {   charPtr = SSLEncodeInt(charPtr, payloadSize, 2);
+        *charPtr++ = padding;
+    }
+    
+    payload.data = charPtr;
+    payload.length = payloadSize;
+    
+    mac.data = charPtr;
+    mac.length = ctx->writeCipher.macRef->hash->digestSize;
+    charPtr += mac.length;
+    
+    content.data = charPtr;
+    content.length = rec.contents.length + padding;
+    memcpy(charPtr, rec.contents.data, rec.contents.length);
+    charPtr += rec.contents.length;
+    i = padding;
+    while (i--)
+        *charPtr++ = padding;
+
+    assert(charPtr == out->data.data + out->data.length);
+    
+    secret.data = ctx->writeCipher.macSecret;
+    secret.length = ctx->writeCipher.symCipher->keySize;
+    if (mac.length > 0)
+        if ((err = SSL2CalculateMAC(secret, content, 
+                               ctx->writeCipher.sequenceNum.low,
+                *ctx->writeCipher.macRef->hash, mac, ctx)) != 0)
+            goto fail;
+    
+    if ((err = ctx->writeCipher.symCipher->encrypt(payload, 
+               payload, 
+               &ctx->writeCipher, 
+               ctx)) != 0)
+        goto fail;
+    
+    /* Enqueue the record to be written from the idle loop */
+    if (ctx->recordWriteQueue == 0)
+        ctx->recordWriteQueue = out;
+    else
+    {   queue = ctx->recordWriteQueue;
+        while (queue->next != 0)
+            queue = queue->next;
+        queue->next = out;
+    }
+    
+    /* Increment the sequence number */
+    IncrementUInt64(&ctx->writeCipher.sequenceNum);
+    
+    return noErr;
+    
+fail:   
+       /* 
+        * Only for if we fail between when the WaitingRecord is allocated and 
+        * when it is queued 
+        */
+    SSLFreeBuffer(out->data, 0);
+    buf.data = (UInt8*)out;
+    buf.length = sizeof(WaitingRecord);
+    SSLFreeBuffer(buf, ctx);
+    return err;
+}
+
+static OSStatus
+SSL2DecryptRecord(SSLBuffer &payload, SSLContext *ctx)
+{   OSStatus        err;
+    SSLBuffer       content;
+    
+    if (ctx->readCipher.symCipher->blockSize > 0)
+        if (payload.length % ctx->readCipher.symCipher->blockSize != 0)
+            return errSSLProtocol;
+    
+       /* Decrypt in place */
+    if ((err = ctx->readCipher.symCipher->decrypt(payload, 
+               payload, 
+               &ctx->readCipher, 
+               ctx)) != 0)
+        return err;
+    
+    if (ctx->readCipher.macRef->hash->digestSize > 0)       
+               /* Optimize away MAC for null case */
+    {   content.data = payload.data + ctx->readCipher.macRef->hash->digestSize;                        /* Data is after MAC */
+        content.length = payload.length - ctx->readCipher.macRef->hash->digestSize;
+        if ((err = SSL2VerifyMAC(content, payload.data, ctx)) != 0)
+            return err;
+               /* Adjust payload to remove MAC; caller is still responsible 
+                * for removing padding [if any] */
+        payload = content;
+    }
+    
+    return noErr;
+}
+
+#define IGNORE_MAC_FAILURE     0
+
+static OSStatus
+SSL2VerifyMAC(SSLBuffer &content, UInt8 *compareMAC, SSLContext *ctx)
+{   OSStatus    err;
+    UInt8       calculatedMAC[SSL_MAX_DIGEST_LEN];
+    SSLBuffer   secret, mac;
+    
+    secret.data = ctx->readCipher.macSecret;
+    secret.length = ctx->readCipher.symCipher->keySize;
+    mac.data = calculatedMAC;
+    mac.length = ctx->readCipher.macRef->hash->digestSize;
+    if ((err = SSL2CalculateMAC(secret, content, ctx->readCipher.sequenceNum.low,
+                                *ctx->readCipher.macRef->hash, mac, ctx)) != 0)
+        return err;
+    if (memcmp(mac.data, compareMAC, mac.length) != 0) {
+               #if     IGNORE_MAC_FAILURE
+               sslErrorLog("SSL2VerifyMAC: Mac verify failure\n");
+               return noErr;
+               #else
+               sslErrorLog("SSL2VerifyMAC: Mac verify failure\n");
+        return errSSLProtocol;
+        #endif
+    }
+    return noErr;
+}
+
+#define LOG_MAC_DATA           0
+#if            LOG_MAC_DATA
+static void logMacData(
+       char *field,
+       SSLBuffer *data)
+{
+       int i;
+       
+       printf("%s: ", field);
+       for(i=0; i<data->length; i++) {
+               printf("%02X", data->data[i]);
+               if((i % 4) == 3) {
+                       printf(" ");
+               }
+       }
+       printf("\n");
+}
+#else  /* LOG_MAC_DATA */
+#define logMacData(f, d)
+#endif /* LOG_MAC_DATA */
+
+/* For SSL 2, the MAC is hash ( secret || content || sequence# )
+ *  where secret is the decryption key for the message, content is
+ *  the record data plus any padding used to round out the record
+ *  size to an even multiple of the block size and sequence# is
+ *  a monotonically increasing 32-bit unsigned integer.
+ */
+static OSStatus
+SSL2CalculateMAC(
+       SSLBuffer &secret, 
+       SSLBuffer &content, 
+       UInt32 seqNo, 
+       const HashReference &hash, 
+       SSLBuffer &mac, 
+       SSLContext *ctx)
+{   OSStatus    err;
+    UInt8       sequenceNum[4];
+    SSLBuffer   seqData, hashContext;
+    
+    SSLEncodeInt(sequenceNum, seqNo, 4);
+    seqData.data = sequenceNum;
+    seqData.length = 4;
+       
+    hashContext.data = 0;
+    if ((err = ReadyHash(hash, hashContext, ctx)) != 0)
+        return err;
+    if ((err = hash.update(hashContext, secret)) != 0)
+        goto fail;
+    if ((err = hash.update(hashContext, content)) != 0)
+        goto fail;
+    if ((err = hash.update(hashContext, seqData)) != 0)
+        goto fail;
+    if ((err = hash.final(hashContext, mac)) != 0)
+        goto fail;
+
+       logMacData("secret ", &secret);
+       logMacData("seqData", &seqData);
+       logMacData("mac    ", &mac);
+    
+    err = noErr;
+fail:
+    SSLFreeBuffer(hashContext, ctx);
+    return err;
+}
+
+OSStatus
+SSL2SendError(SSL2ErrorCode error, SSLContext *ctx)
+{   OSStatus        err;
+    SSLRecord       rec;
+    UInt8           errorData[3];
+    
+    rec.contentType = SSL_RecordTypeV2_0;
+    rec.protocolVersion = SSL_Version_2_0;
+    rec.contents.data = errorData;
+    rec.contents.length = 3;
+    errorData[0] = SSL2_MsgError;
+    SSLEncodeInt(errorData + 1, error, 2);
+    
+    err = SSL2WriteRecord(rec, ctx);
+    return err;
+}