]> git.saurik.com Git - apple/security.git/blobdiff - libsecurity_ssl/lib/sslUtils.c
Security-55163.44.tar.gz
[apple/security.git] / libsecurity_ssl / lib / sslUtils.c
diff --git a/libsecurity_ssl/lib/sslUtils.c b/libsecurity_ssl/lib/sslUtils.c
new file mode 100644 (file)
index 0000000..91a4f95
--- /dev/null
@@ -0,0 +1,318 @@
+/*
+ * Copyright (c) 1999-2001,2005-2008,2010-2012 Apple Inc. All Rights Reserved.
+ *
+ * @APPLE_LICENSE_HEADER_START@
+ *
+ * This file contains Original Code and/or Modifications of Original Code
+ * as defined in and that are subject to the Apple Public Source License
+ * Version 2.0 (the 'License'). You may not use this file except in
+ * compliance with the License. Please obtain a copy of the License at
+ * http://www.opensource.apple.com/apsl/ and read it before using this
+ * file.
+ *
+ * The Original Code and all software distributed under the License are
+ * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
+ * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
+ * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
+ * Please see the License for the specific language governing rights and
+ * limitations under the License.
+ *
+ * @APPLE_LICENSE_HEADER_END@
+ */
+
+/*
+ * sslUtils.c - Misc. SSL utility functions
+ */
+
+#include "sslContext.h"
+#include "sslUtils.h"
+#include "sslMemory.h"
+#include "sslDebug.h"
+
+#include <sys/time.h>
+
+#include <fcntl.h> // for open
+#include <unistd.h> // for read
+#include <errno.h> // for errno
+#include <Security/SecFramework.h>
+#include <Security/SecRandom.h>
+
+#ifndef NDEBUG
+void SSLDump(const unsigned char *data, unsigned long len)
+{
+    unsigned long i;
+    for(i=0;i<len;i++)
+    {
+        if((i&0xf)==0) printf("%04lx :",i);
+        printf(" %02x", data[i]);
+        if((i&0xf)==0xf) printf("\n");
+    }
+    printf("\n");
+}
+#endif
+
+unsigned int
+SSLDecodeInt(const uint8_t *p, size_t length)
+{
+    unsigned int val = 0;
+    assert(length > 0 && length <= 4); //anything else would be an internal error.
+    while (length--)
+        val = (val << 8) | *p++;
+    return val;
+}
+
+uint8_t *
+SSLEncodeInt(uint8_t *p, unsigned int value, size_t length)
+{
+    unsigned char *retVal = p + length; /* Return pointer to char after int */
+    assert(length > 0 && length <= 4);  //anything else would be an internal error.
+    while (length--)                    /* Assemble backwards */
+    {   p[length] = (uint8_t)value;     /* Implicit masking to low byte */
+        value >>= 8;
+    }
+    return retVal;
+}
+
+size_t
+SSLDecodeSize(const uint8_t *p, size_t length)
+{
+    unsigned int val = 0;
+    assert(length > 0 && length <= 4); //anything else would be an internal error.
+    while (length--)
+        val = (val << 8) | *p++;
+    return val;
+}
+
+uint8_t *
+SSLEncodeSize(uint8_t *p, size_t value, size_t length)
+{
+    unsigned char *retVal = p + length; /* Return pointer to char after int */
+    assert(length > 0 && length <= 4);  //anything else would be an internal error.
+    while (length--)                    /* Assemble backwards */
+    {   p[length] = (uint8_t)value;     /* Implicit masking to low byte */
+        value >>= 8;
+    }
+    return retVal;
+}
+
+
+uint8_t *
+SSLEncodeUInt64(uint8_t *p, sslUint64 value)
+{
+    p = SSLEncodeInt(p, value.high, 4);
+    return SSLEncodeInt(p, value.low, 4);
+}
+
+
+uint8_t *
+SSLEncodeHandshakeHeader(SSLContext *ctx, SSLRecord *rec, SSLHandshakeType type, size_t msglen)
+{
+    uint8_t *charPtr;
+
+    charPtr = rec->contents.data;
+    *charPtr++ = type;
+    charPtr = SSLEncodeSize(charPtr, msglen, 3);
+
+    if(rec->protocolVersion == DTLS_Version_1_0) {
+        charPtr = SSLEncodeInt(charPtr, ctx->hdskMessageSeq, 2);
+        /* fragmentation -- we encode header as if unfragmented,
+           actual fragmentation happens at lower layer. */
+        charPtr = SSLEncodeInt(charPtr, 0, 3);
+        charPtr = SSLEncodeSize(charPtr, msglen, 3);
+    }
+
+    return charPtr;
+}
+
+
+void
+IncrementUInt64(sslUint64 *v)
+{   if (++v->low == 0)          /* Must have just rolled over */
+        ++v->high;
+}
+
+#if ENABLE_DTLS
+void
+SSLDecodeUInt64(const uint8_t *p, size_t length, sslUint64 *v)
+{
+    assert(length > 0 && length <= 8);
+    if(length<=4) {
+        v->low=SSLDecodeInt(p, length);
+        v->high=0;
+    } else {
+        v->high=SSLDecodeInt(p, length-4);
+        v->low=SSLDecodeInt(p+length-4, 4);
+    }
+}
+#endif
+
+#ifdef USE_SSLCERTIFICATE
+size_t
+SSLGetCertificateChainLength(const SSLCertificate *c)
+{
+       size_t rtn = 0;
+
+    while (c)
+    {
+       rtn++;
+        c = c->next;
+    }
+    return rtn;
+}
+
+OSStatus sslDeleteCertificateChain(
+    SSLCertificate             *certs,
+       SSLContext                      *ctx)
+{
+       SSLCertificate          *cert;
+       SSLCertificate          *nextCert;
+
+       assert(ctx != NULL);
+       cert=certs;
+       while(cert != NULL) {
+               nextCert = cert->next;
+               SSLFreeBuffer(&cert->derCert, ctx);
+               sslFree(cert);
+               cert = nextCert;
+       }
+       return noErr;
+}
+#endif /* USE_SSLCERTIFICATE */
+
+Boolean sslIsSessionActive(const SSLContext *ctx)
+{
+       assert(ctx != NULL);
+       switch(ctx->state) {
+               case SSL_HdskStateUninit:
+               case SSL_HdskStateServerUninit:
+               case SSL_HdskStateClientUninit:
+               case SSL_HdskStateGracefulClose:
+               case SSL_HdskStateErrorClose:
+                       return false;
+               default:
+                       return true;
+       }
+}
+
+#if    SSL_DEBUG
+
+const char *protocolVersStr(SSLProtocolVersion prot)
+{
+       switch(prot) {
+       case SSL_Version_Undetermined: return "SSL_Version_Undetermined";
+       case SSL_Version_2_0: return "SSL_Version_2_0";
+       case SSL_Version_3_0: return "SSL_Version_3_0";
+       case TLS_Version_1_0: return "TLS_Version_1_0";
+    case TLS_Version_1_1: return "TLS_Version_1_1";
+    case TLS_Version_1_2: return "TLS_Version_1_2";
+       default: sslErrorLog("protocolVersStr: bad prot\n"); return "BAD PROTOCOL";
+       }
+       return NULL;    /* NOT REACHED */
+}
+
+#endif /* SSL_DEBUG */
+
+/*
+ * Redirect SSLBuffer-based I/O call to user-supplied I/O.
+ */
+OSStatus sslIoRead(
+       SSLBuffer               buf,
+       size_t                  *actualLength,
+       SSLContext              *ctx)
+{
+       size_t          dataLength = buf.length;
+       OSStatus        ortn;
+
+       *actualLength = 0;
+       ortn = (ctx->ioCtx.read)(ctx->ioCtx.ioRef,
+               buf.data,
+               &dataLength);
+       *actualLength = dataLength;
+       return ortn;
+}
+
+OSStatus sslIoWrite(
+       SSLBuffer               buf,
+       size_t                  *actualLength,
+       SSLContext              *ctx)
+{
+       size_t                  dataLength = buf.length;
+       OSStatus                ortn;
+
+       *actualLength = 0;
+       ortn = (ctx->ioCtx.write)(ctx->ioCtx.ioRef,
+               buf.data,
+               &dataLength);
+       *actualLength = dataLength;
+       return ortn;
+}
+
+OSStatus sslTime(uint32_t *tim)
+{
+       time_t t;
+       time(&t);
+       *tim = (uint32_t)t;
+       return noErr;
+}
+
+/*
+ * Common RNG function.
+ */
+OSStatus sslRand(SSLContext *ctx, SSLBuffer *buf)
+{
+       assert(buf != NULL);
+       assert(buf->data != NULL);
+
+       if(buf->length == 0) {
+               sslErrorLog("sslRand: zero buf->length\n");
+               return noErr;
+       }
+
+       return SecRandomCopyBytes(kSecRandomDefault, buf->length, buf->data) ? errSSLCrypto : noErr;
+}
+
+/*
+ * Given a protocol version sent by peer, determine if we accept that version
+ * and downgrade if appropriate (which can not be done for the client side).
+ */
+OSStatus sslVerifyProtVersion(
+       SSLContext                      *ctx,
+       SSLProtocolVersion      peerVersion,    // sent by peer
+       SSLProtocolVersion      *negVersion)    // final negotiated version if return success
+{
+    if ((ctx->isDTLS)
+        ? peerVersion > ctx->minProtocolVersion
+        : peerVersion < ctx->minProtocolVersion) {
+        return errSSLNegotiation;
+    }
+    if ((ctx->isDTLS)
+        ? peerVersion < ctx->maxProtocolVersion
+        : peerVersion > ctx->maxProtocolVersion) {
+        if (ctx->protocolSide == kSSLClientSide) {
+            return errSSLNegotiation;
+        }
+        *negVersion = ctx->maxProtocolVersion;
+    } else {
+        *negVersion = peerVersion;
+    }
+
+    return noErr;
+}
+
+/*
+ * Determine max enabled protocol, i.e., the one we try to negotiate for.
+ * Only returns an error (paramErr) if NO protocols are enabled, which can
+ * in fact happen by malicious or ignorant use of SSLSetProtocolVersionEnabled().
+ */
+OSStatus sslGetMaxProtVersion(
+       SSLContext                      *ctx,
+       SSLProtocolVersion      *version)       // RETURNED
+{
+    /* This check is here until SSLSetProtocolVersionEnabled() is gone .*/
+    if (ctx->maxProtocolVersion == SSL_Version_Undetermined)
+        return badReqErr;
+
+    *version = ctx->maxProtocolVersion;
+    return noErr;
+}