]> git.saurik.com Git - apple/security.git/blobdiff - tlsnke/tlsnketest/dtls_client.c
Security-55471.tar.gz
[apple/security.git] / tlsnke / tlsnketest / dtls_client.c
diff --git a/tlsnke/tlsnketest/dtls_client.c b/tlsnke/tlsnketest/dtls_client.c
new file mode 100644 (file)
index 0000000..1dc76b0
--- /dev/null
@@ -0,0 +1,261 @@
+//
+//  dtls_client.c
+//  tlsnke
+//
+//  Created by Fabrice Gautier on 2/7/12.
+//  Copyright (c) 2012 Apple, Inc. All rights reserved.
+//
+
+/*
+ *  Originally dtlsEchoClient.c
+ *  Security
+ *
+ *  Created by Fabrice Gautier on 1/31/11.
+ *  Copyright 2011 Apple, Inc. All rights reserved.
+ *
+ */
+
+#include <Security/Security.h>
+
+#include "ssl-utils.h"
+
+#include <stdlib.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <stdio.h>
+#include <errno.h>
+#include <unistd.h> /* close() */
+#include <string.h> /* memset() */
+#include <fcntl.h>
+#include <time.h>
+
+#include "tlssocket.h"
+
+/* Test parameters. NOTE(review): SERVER is not referenced in this file;
+ * dtls_client() takes the peer address as its hostname argument. */
+#define SERVER "10.0.2.1"
+#define PORT 23232      /* UDP port dtls_client() connects to */
+#define BUFLEN 128      /* size of the message / reply buffer */
+#define COUNT 10        /* number of messages dtls_client() sends */
+
+#if 0
+/*
+ * Debug helper (compiled out): hex-dump len bytes of data to stdout,
+ * 16 bytes per row, each row prefixed with its hex offset.
+ */
+static void dumppacket(const unsigned char *data, unsigned long len)
+{
+    unsigned long off = 0;
+
+    while (off < len) {
+        if (off % 16 == 0)
+            printf("%04lx :", off);      /* start of a new row */
+        printf(" %02x", data[off]);
+        if (off % 16 == 15)
+            printf("\n");                /* row complete */
+        off++;
+    }
+    printf("\n");
+}
+#endif
+
+
+/*
+ * Progress indicator used while the non-blocking handshake would block:
+ * print a '.' (flushing stdout) at most once every TIME_INTERVAL seconds.
+ */
+static time_t lastTime = (time_t)0;   /* wall-clock time of the last dot */
+#define TIME_INTERVAL          3       /* minimum seconds between dots */
+
+static void sslOutputDot(void)
+{
+    time_t thisTime = time(NULL);
+
+    /* difftime() is the portable way to subtract two time_t values. */
+    if (difftime(thisTime, lastTime) >= TIME_INTERVAL) {
+        printf(".");
+        fflush(stdout);
+        lastTime = thisTime;
+    }
+}
+
+/*
+ * Report a failed call on stdout as "*** <op>: <status>", with the
+ * status code printed as a signed long.
+ */
+static void printSslErrStr(const char *op, OSStatus err)
+{
+    long status = (long)err;
+
+    fputs("*** ", stdout);
+    printf("%s: %ld\n", op, status);
+}
+
+/* Largest datagram size this test plans for (2 KiB).
+ * NOTE(review): not referenced in this file. */
+#define MTU 2048
+
+/* Delay between polls while waiting for a reply; with 10000 retries this
+ * bounds the wait at roughly 10 seconds regardless of CPU speed. */
+#define READ_POLL_NSEC 1000000L
+
+
+int dtls_client(const char *hostname, int bypass);
+
+/*
+ * dtls_client - run a short DTLS message exchange against hostname:PORT.
+ *
+ * Opens a connected, non-blocking UDP socket, attaches the tlsnke filter
+ * with TLSSocket_Attach(), performs a client-side DTLS handshake through
+ * SecureTransport using the TLSSocket_Funcs record layer, then sends COUNT
+ * messages and prints each reply.
+ *
+ * hostname: IPv4 address in dotted form (parsed with inet_aton, no DNS).
+ * bypass:   when non-zero, application data is written to and read from
+ *           /dev/tlsnke instead of going through SSLWrite()/SSLRead().
+ *
+ * Returns 0 on success; an errno value for connect/fcntl/bypass I/O
+ * failures; -1 if the SSL context cannot be created; otherwise the failing
+ * OSStatus. Read timeouts are reported but do not make the call fail.
+ * Exits the process if the socket cannot be created, the address cannot
+ * be parsed, or /dev/tlsnke cannot be opened.
+ */
+int dtls_client(const char *hostname, int bypass)
+{
+    int fd = -1;                 /* UDP socket */
+    int tlsfd = -1;              /* /dev/tlsnke side channel (bypass only) */
+    int flags;
+    int rc = 0;                  /* value returned to the caller */
+    OSStatus ortn;
+    SSLContextRef ctx = NULL;
+    SSLRecordContextRef c;
+    struct sockaddr_in sa;
+    char buffer[BUFLEN];
+    static const struct timespec pollDelay = { 0, READ_POLL_NSEC };
+
+    printf("Running dtls_client test with hostname=%s, bypass=%d\n", hostname, bypass);
+
+    if ((fd = socket(AF_INET, SOCK_DGRAM, 0)) == -1) {
+        perror("socket");
+        exit(-1);
+    }
+
+    memset(&sa, 0, sizeof(sa));
+    sa.sin_family = AF_INET;
+    sa.sin_port = htons(PORT);
+    if (inet_aton(hostname, &sa.sin_addr) == 0) {
+        fprintf(stderr, "inet_aton() failed\n");
+        exit(1);
+    }
+
+    if (connect(fd, (struct sockaddr *)&sa, sizeof(sa)) == -1) {
+        rc = errno;              /* save before perror() can touch errno */
+        perror("connect");
+        goto out;
+    }
+
+    /* Switch to non-blocking I/O without clobbering other status flags. */
+    flags = fcntl(fd, F_GETFL, 0);
+    if (flags == -1 || fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
+        rc = errno;
+        perror("fcntl(O_NONBLOCK)");
+        goto out;
+    }
+
+    /* The socket descriptor is handed to SSLSetRecordContext() below as the
+     * opaque record context used with TLSSocket_Funcs. */
+    c = (SSLRecordContextRef)(intptr_t)fd;
+
+    /*
+     * Set up a SecureTransport session.
+     */
+    ctx = SSLCreateContextWithRecordFuncs(kCFAllocatorDefault, kSSLClientSide, kSSLDatagramType, &TLSSocket_Funcs);
+    if (!ctx) {
+        printSslErrStr("SSLCreateContextWithRecordFuncs", -1);
+        rc = -1;
+        goto out;
+    }
+
+    printf("Attaching filter\n");
+    ortn = TLSSocket_Attach(fd);
+    if (ortn) {
+        printSslErrStr("TLSSocket_Attach", ortn);
+        rc = ortn;
+        goto release_ctx;
+    }
+
+    if (bypass) {
+        tlsfd = open("/dev/tlsnke", O_RDWR);
+        if (tlsfd < 0) {
+            perror("opening tlsnke dev");
+            exit(-1);
+        }
+    }
+
+    ortn = SSLSetRecordContext(ctx, c);
+    if (ortn) {
+        printSslErrStr("SSLSetRecordContext", ortn);
+        rc = ortn;
+        goto release_ctx;
+    }
+
+    ortn = SSLSetMaxDatagramRecordSize(ctx, 600);
+    if (ortn) {
+        printSslErrStr("SSLSetMaxDatagramRecordSize", ortn);
+        rc = ortn;
+        goto release_ctx;
+    }
+
+    /* Lets not verify the cert, which is a random test cert */
+    ortn = SSLSetEnableCertVerify(ctx, false);
+    if (ortn) {
+        printSslErrStr("SSLSetEnableCertVerify", ortn);
+        rc = ortn;
+        goto release_ctx;
+    }
+
+    ortn = SSLSetCertificate(ctx, server_chain());
+    if (ortn) {
+        printSslErrStr("SSLSetCertificate", ortn);
+        rc = ortn;
+        goto release_ctx;
+    }
+
+    printf("Handshake...\n");
+    do {
+        ortn = SSLHandshake(ctx);
+        if (ortn == errSSLWouldBlock) {
+            /* keep UI responsive */
+            sslOutputDot();
+        }
+    } while (ortn == errSSLWouldBlock);
+
+    /* Do not exchange application data over a session that never came up. */
+    if (ortn) {
+        printSslErrStr("SSLHandshake", ortn);
+        rc = ortn;
+        goto close_session;
+    }
+
+    for (int count = 0; count < COUNT; count++) {
+        int timeout = 10000;     /* extra polls before giving up on a reply */
+        size_t len, readLen, writeLen;
+        ssize_t sreadLen, swriteLen;
+
+        snprintf(buffer, sizeof(buffer), "Message %d", count);
+        len = strlen(buffer);
+
+        if (bypass) {
+            /* Send data through the side channel, kind of like utun would */
+            swriteLen = write(tlsfd, buffer, len);
+            if (swriteLen < 0) {
+                rc = errno;
+                perror("write to tlsfd");
+                break;
+            }
+            writeLen = (size_t)swriteLen;
+        } else {
+            ortn = SSLWrite(ctx, buffer, len, &writeLen);
+            if (ortn) {
+                printSslErrStr("SSLWrite", ortn);
+                rc = ortn;
+                break;
+            }
+        }
+
+        printf("Wrote %zu bytes\n", writeLen);
+
+        /* Read at most BUFLEN - 1 bytes so the NUL terminator always fits. */
+        if (bypass) {
+            for (;;) {
+                sreadLen = read(tlsfd, buffer, sizeof(buffer) - 1);
+                if (sreadLen != -1 || errno != EAGAIN || timeout-- == 0)
+                    break;
+                nanosleep(&pollDelay, NULL);
+            }
+            if (sreadLen == -1 && errno == EAGAIN) {
+                printf("Read timeout...\n");
+                continue;
+            }
+            if (sreadLen < 0) {
+                rc = errno;
+                perror("read from tlsfd");
+                break;
+            }
+            readLen = (size_t)sreadLen;
+        } else {
+            for (;;) {
+                ortn = SSLRead(ctx, buffer, sizeof(buffer) - 1, &readLen);
+                if (ortn != errSSLWouldBlock || timeout-- == 0)
+                    break;
+                nanosleep(&pollDelay, NULL);
+            }
+            if (ortn == errSSLWouldBlock) {
+                printf("SSLRead timeout...\n");
+                continue;
+            }
+            if (ortn) {
+                printSslErrStr("SSLRead", ortn);
+                rc = ortn;
+                break;
+            }
+        }
+
+        buffer[readLen] = 0;
+        printf("Received %zu bytes: %s\n", readLen, buffer);
+    }
+
+close_session:
+    SSLClose(ctx);
+release_ctx:
+    /* SSLCreateContext-family contexts are CF objects released with
+     * CFRelease (SSLDisposeContext is the deprecated SSLNewContext path). */
+    CFRelease(ctx);
+out:
+    if (tlsfd >= 0)
+        close(tlsfd);
+    if (fd >= 0)
+        close(fd);
+    return rc;
+}
+