]> git.saurik.com Git - apple/security.git/blobdiff - OSX/libsecurity_ssl/regressions/ssl-49-sni.c
Security-57336.1.9.tar.gz
[apple/security.git] / OSX / libsecurity_ssl / regressions / ssl-49-sni.c
diff --git a/OSX/libsecurity_ssl/regressions/ssl-49-sni.c b/OSX/libsecurity_ssl/regressions/ssl-49-sni.c
new file mode 100644 (file)
index 0000000..0afdde5
--- /dev/null
@@ -0,0 +1,281 @@
+//
+//  ssl-49-sni.c
+//  libsecurity_ssl
+//
+//
+
+
+#include <stdbool.h>
+#include <pthread.h>
+#include <fcntl.h>
+#include <sys/mman.h>
+#include <unistd.h>
+
+#include <CoreFoundation/CoreFoundation.h>
+
+#include <AssertMacros.h>
+#include <Security/SecureTransportPriv.h> /* SSLSetOption */
+#include <Security/SecureTransport.h>
+#include <Security/SecPolicy.h>
+#include <Security/SecTrust.h>
+#include <Security/SecIdentity.h>
+#include <Security/SecIdentityPriv.h>
+#include <Security/SecCertificatePriv.h>
+#include <Security/SecKeyPriv.h>
+#include <Security/SecItem.h>
+#include <Security/SecRandom.h>
+
+#include <string.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <errno.h>
+#include <stdlib.h>
+#include <mach/mach_time.h>
+
+#if TARGET_OS_IPHONE
+#include <Security/SecRSAKey.h>
+#endif
+
+#include "ssl_regressions.h"
+#include "ssl-utils.h"
+
+typedef struct {
+    SSLContextRef handle;
+    uint32_t session_id;
+    bool is_server;
+    int comm;
+} ssl_test_handle;
+
+
+#pragma mark -
+#pragma mark SecureTransport support
+
+#if 0
+static void hexdump(const uint8_t *bytes, size_t len) {
+    size_t ix;
+    printf("socket write(%p, %lu)\n", bytes, len);
+    for (ix = 0; ix < len; ++ix) {
+        if (!(ix % 16))
+            printf("\n");
+        printf("%02X ", bytes[ix]);
+    }
+    printf("\n");
+}
+#else
+#define hexdump(bytes, len)
+#endif
+
+
+static OSStatus SocketWrite(SSLConnectionRef h, const void *data, size_t *length)
+{
+    size_t len = *length;
+    uint8_t *ptr = (uint8_t *)data;
+
+    do {
+        ssize_t ret;
+        do {
+            hexdump(ptr, len);
+            ret = write((int)h, ptr, len);
+        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
+        if (ret > 0) {
+            len -= ret;
+            ptr += ret;
+        }
+        else
+            return -36;
+    } while (len > 0);
+
+    *length = *length - len;
+    return errSecSuccess;
+}
+
+static OSStatus SocketRead(SSLConnectionRef h, void *data, size_t *length)
+{
+    size_t len = *length;
+    uint8_t *ptr = (uint8_t *)data;
+
+    do {
+        ssize_t ret;
+        do {
+            ret = read((int)h, ptr, len);
+        } while ((ret < 0) && (errno == EAGAIN || errno == EINTR));
+        if (ret > 0) {
+            len -= ret;
+            ptr += ret;
+        } else {
+            printf("read error(%d): ret=%zd, errno=%d\n", (int)h, ret, errno);
+            return -errno;
+        }
+    } while (len > 0);
+
+    *length = *length - len;
+    return errSecSuccess;
+}
+
+static char peername[] = "localhost";
+
+static void *securetransport_server_thread(void *arg)
+{
+    OSStatus ortn;
+    ssl_test_handle * ssl = (ssl_test_handle *)arg;
+    SSLContextRef ctx = ssl->handle;
+    CFArrayRef server_certs = server_chain();
+
+    do {
+        ortn = SSLHandshake(ctx);
+    } while (ortn == errSSLWouldBlock);
+
+    ok(ortn==errSSLClientHelloReceived, "Unexpected Handshake exit code");
+
+    if (ortn == errSSLClientHelloReceived) {
+        char *sni = NULL;
+        size_t length = 0;
+        SSLCopyRequestedPeerNameLength(ctx, &length);
+        if (length > 0) {
+            sni = malloc(length);
+            SSLCopyRequestedPeerName(ctx, sni, &length);
+        }
+
+        SSLProtocol version = 0;
+        require_noerr(SSLGetProtocolVersionMax(ctx, &version), out);
+        if (version == kSSLProtocol3) {
+            ok(sni==NULL, "Unexpected SNI");
+        } else {
+            ok(sni!=NULL &&
+               length == sizeof(peername) &&
+               (memcmp(sni, peername, sizeof(peername))==0),
+               "SNI does not match");
+        }
+        require_noerr(SSLSetCertificate(ctx, server_certs), out);
+    }
+
+out:
+    SSLClose(ctx);
+    SSLDisposeContext(ctx);
+    close(ssl->comm);
+
+    pthread_exit((void *)(intptr_t)ortn);
+    return NULL;
+}
+
+static void *securetransport_client_thread(void *arg)
+{
+    OSStatus ortn;
+    ssl_test_handle * ssl = (ssl_test_handle *)arg;
+    SSLContextRef ctx = ssl->handle;
+
+    do {
+        ortn = SSLHandshake(ctx);
+    } while (ortn == errSSLWouldBlock || ortn != errSSLClosedGraceful);
+
+    SSLClose(ctx);
+    SSLDisposeContext(ctx);
+    close(ssl->comm);
+
+    pthread_exit((void *)(intptr_t)ortn);
+    return NULL;
+}
+
+static SSLCipherSuite ciphers[] = {
+    TLS_RSA_WITH_AES_128_CBC_SHA,
+    //FIXME: re-enable this test when its fixed.
+    //TLS_RSA_WITH_RC4_128_SHA,
+};
+
+static ssl_test_handle *
+ssl_test_handle_create(uint32_t session_id, bool server, int comm)
+{
+    ssl_test_handle *handle = calloc(1, sizeof(ssl_test_handle));
+    SSLContextRef ctx = SSLCreateContext(kCFAllocatorDefault, server?kSSLServerSide:kSSLClientSide, kSSLStreamType);
+
+    require(handle, out);
+    require(ctx, out);
+
+    require_noerr(SSLSetIOFuncs(ctx,
+                                (SSLReadFunc)SocketRead, (SSLWriteFunc)SocketWrite), out);
+    require_noerr(SSLSetConnection(ctx, (SSLConnectionRef)(intptr_t)comm), out);
+
+    if (server)
+        require_noerr(SSLSetSessionOption(ctx,
+                                          kSSLSessionOptionBreakOnClientHello, true), out);
+    else
+        require_noerr(SSLSetSessionOption(ctx,
+                                          kSSLSessionOptionBreakOnServerAuth, true), out);
+
+    /* Tell SecureTransport to not check certs itself: it will break out of the
+     handshake to let us take care of it instead. */
+    require_noerr(SSLSetEnableCertVerify(ctx, false), out);
+
+    handle->handle = ctx;
+    handle->is_server = server;
+    handle->session_id = session_id;
+    handle->comm = comm;
+
+    return handle;
+
+out:
+   if (ctx) CFRelease(ctx);
+    return NULL;
+}
+
+static SSLProtocol versions[] = {
+    kSSLProtocol3,
+    kTLSProtocol1,
+    kTLSProtocol11,
+    kTLSProtocol12,
+};
+static int nversions = sizeof(versions)/sizeof(versions[0]);
+
+static void
+tests(void)
+{
+    int j;
+    pthread_t client_thread, server_thread;
+
+    for(j=0; j<nversions; j++)
+    {
+        int sp[2];
+        if (socketpair(AF_UNIX, SOCK_STREAM, 0, sp)) exit(errno);
+
+        ssl_test_handle *server, *client;
+
+        uint32_t session_id = (j+1) << 16 | 1 << 8;
+        server = ssl_test_handle_create(session_id, true /*server*/, sp[0]);
+        client = ssl_test_handle_create(session_id, false/*client*/, sp[1]);
+
+        require_noerr(SSLSetPeerID(server->handle, &session_id, sizeof(session_id)), out);
+        require_noerr(SSLSetPeerID(client->handle, &session_id, sizeof(session_id)), out);
+
+        /* set fixed cipher on client and server */
+        require_noerr(SSLSetEnabledCiphers(client->handle, &ciphers[0], 1), out);
+        require_noerr(SSLSetEnabledCiphers(server->handle, &ciphers[0], 1), out);
+
+        require_noerr(SSLSetProtocolVersionMax(client->handle, versions[j]), out);
+        require_noerr(SSLSetPeerDomainName(client->handle, peername, sizeof(peername)), out);
+
+        require_noerr(SSLSetProtocolVersionMax(server->handle, versions[j]), out);
+
+        pthread_create(&client_thread, NULL, securetransport_client_thread, client);
+        pthread_create(&server_thread, NULL, securetransport_server_thread, server);
+
+        int server_err, client_err;
+        pthread_join(client_thread, (void*)&client_err);
+        pthread_join(server_thread, (void*)&server_err);
+
+    out:
+        free(client);
+        free(server);
+
+    }
+}
+
+int ssl_49_sni(int argc, char *const *argv)
+{
+
+    plan_tests(8);
+
+
+    tests();
+
+    return 0;
+}