]> git.saurik.com Git - apple/security.git/blobdiff - Security/tlsnke/tlsnketest/tlssocket.c
Security-57031.1.35.tar.gz
[apple/security.git] / Security / tlsnke / tlsnketest / tlssocket.c
diff --git a/Security/tlsnke/tlsnketest/tlssocket.c b/Security/tlsnke/tlsnketest/tlssocket.c
new file mode 100644 (file)
index 0000000..6f0e662
--- /dev/null
@@ -0,0 +1,344 @@
+/*
+ * Copyright (c) 2012-2014 Apple Inc. All Rights Reserved.
+ *
+ * @APPLE_LICENSE_HEADER_START@
+ * 
+ * This file contains Original Code and/or Modifications of Original Code
+ * as defined in and that are subject to the Apple Public Source License
+ * Version 2.0 (the 'License'). You may not use this file except in
+ * compliance with the License. Please obtain a copy of the License at
+ * http://www.opensource.apple.com/apsl/ and read it before using this
+ * file.
+ * 
+ * The Original Code and all software distributed under the License are
+ * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
+ * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
+ * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
+ * Please see the License for the specific language governing rights and
+ * limitations under the License.
+ * 
+ * @APPLE_LICENSE_HEADER_END@
+ */
+
+
+#include <Security/SecureTransportPriv.h>
+#include <string.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
+#include <stdint.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <assert.h>
+
+#include <net/kext_net.h>
+
+#include "tlssocket.h"
+#include "tlsnke.h"
+
+#include <AssertMacros.h>
+#include <errno.h>
+
+/* TLSSocket functions */
+
+/*
+ * SSLRecordFuncs.read callback: receive one record from the socket whose
+ * descriptor is carried in `ref`.
+ *
+ * Peeks first so that a non-blocking socket with no data reports
+ * errSSLRecordWouldBlock. Then sizes the receive buffer with SO_NREAD and
+ * reads the payload together with the SCM_TLS_HEADER control message that
+ * supplies the record's content type and protocol version.
+ *
+ * On success returns 0, and rec->contents owns a malloc()ed buffer that is
+ * released by TLSSocket_Free.
+ * On failure returns one of:
+ *   - errSSLRecordWouldBlock
+ *   - errSSLRecordInternal
+ *   - ENOMEM
+ *   - an errno value from recv()/recvmsg()
+ * and no buffer is handed to the caller.
+ */
+static
+int TLSSocket_Read(SSLRecordContextRef ref,
+                        SSLRecord *rec)
+{
+    int fd = (int)(intptr_t)ref;
+    int rc;
+    int status;
+    ssize_t sz;
+    struct sockaddr_in client_addr;
+    int avail;
+    socklen_t avail_size;
+    struct cmsghdr *cmsg;
+    tls_record_hdr_t hdr;
+    struct msghdr msg;
+    struct iovec iov;
+    /* Room for one SCM_TLS_HEADER plus slack. The union gives the buffer the
+       alignment CMSG_FIRSTHDR()/CMSG_DATA() require; a plain byte array
+       does not guarantee it. */
+    union {
+        struct cmsghdr align;
+        uint8_t buf[CMSG_SPACE(sizeof(*hdr)) + 1024];
+    } cbuf;
+    char peek;
+
+    /* PEEK only: wait for data without consuming it. */
+    rc = (int)recv(fd, &peek, 1, MSG_PEEK);
+    if (rc == -1) {
+        int err = errno;            /* save before perror() can disturb it */
+        if (err == EAGAIN)
+            return errSSLRecordWouldBlock;
+        perror("recv");
+        return err;
+    }
+
+    /* Get the size of the pending data. */
+    avail_size = sizeof(avail);
+    rc = getsockopt(fd, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
+    check_noerr(rc);
+    check(avail_size == sizeof(avail));
+    if (rc || avail_size != sizeof(avail))
+        return errSSLRecordInternal;
+
+    /* Also rejects a negative count, which would otherwise become a huge
+       malloc() size. */
+    if (avail <= 0)
+        return errSSLRecordWouldBlock;
+
+    rec->contents.data = malloc((size_t)avail);
+    if (rec->contents.data == NULL) {
+        rec->contents.length = 0;
+        return ENOMEM;
+    }
+    rec->contents.length = (size_t)avail;
+
+    /* Read the message and its control data. */
+    memset(&msg, 0, sizeof(msg));
+    iov.iov_base = rec->contents.data;
+    iov.iov_len = rec->contents.length;
+    msg.msg_name = &client_addr;
+    msg.msg_namelen = sizeof(client_addr);
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cbuf.buf;
+    msg.msg_controllen = sizeof(cbuf.buf);
+
+    sz = recvmsg(fd, &msg, 0);
+    if (sz < 0) {
+        int err = errno;
+        perror("recvmsg");
+        status = (err == EAGAIN) ? errSSLRecordWouldBlock : err;
+        goto fail;
+    }
+    check(sz == avail);
+    rec->contents.length = (size_t)sz;
+
+    /* Without a well-formed SCM_TLS_HEADER the content type and version are
+       unknown, so do not report a half-initialised record as success. Also
+       make sure the control message is large enough before reading hdr. */
+    cmsg = CMSG_FIRSTHDR(&msg);
+    check(cmsg);
+    if (cmsg == NULL
+        || cmsg->cmsg_level != SOL_SOCKET
+        || cmsg->cmsg_type != SCM_TLS_HEADER
+        || cmsg->cmsg_len < CMSG_LEN(sizeof(*hdr))) {
+        status = errSSLRecordInternal;
+        goto fail;
+    }
+    check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
+    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
+
+    rec->contentType = hdr->content_type;
+    rec->protocolVersion = hdr->protocol_version;
+
+    if (rec->contentType == SSL_RecordTypeChangeCipher) {
+        printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
+    }
+    return 0;
+
+fail:
+    /* Nothing is handed back on failure, so release the payload buffer here. */
+    free(rec->contents.data);
+    rec->contents.data = NULL;
+    rec->contents.length = 0;
+    return status;
+}
+
+/*
+ * SSLRecordFuncs.free callback: release the payload buffer of a record
+ * (TLSSocket_Read obtains it with malloc()). The socket handle is unused.
+ * Always returns 0.
+ */
+static
+int TLSSocket_Free(SSLRecordContextRef ref,
+                         SSLRecord rec)
+{
+    (void)ref;
+
+    void *payload = rec.contents.data;
+    free(payload);
+
+    return 0;
+}
+
+/*
+ * SSLRecordFuncs.write callback: send one record on the socket whose
+ * descriptor is carried in `ref`.
+ *
+ * The payload is sent as the message body. The record's content type and
+ * protocol version travel with it in an SCM_TLS_HEADER control message.
+ *
+ * Returns 0 on success, errSSLRecordWouldBlock if the socket would block,
+ * or the errno value from sendmsg() otherwise.
+ */
+static
+int TLSSocket_Write(SSLRecordContextRef ref,
+                          SSLRecord rec)
+{
+    int fd = (int)(intptr_t)ref;
+    ssize_t sz;
+    struct msghdr msg;
+    struct iovec iov;
+    tls_record_hdr_t hdr;
+    struct cmsghdr *cmsg;
+    /* Union gives the control buffer cmsghdr alignment for CMSG_*(). */
+    union {
+        struct cmsghdr align;
+        uint8_t buf[CMSG_SPACE(sizeof(*hdr))];
+    } cbuf;
+
+    if (rec.contentType == SSL_RecordTypeChangeCipher) {
+        printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
+    }
+
+    /* Zero everything so CMSG padding and msg_flags are never uninitialised
+       stack bytes handed to the kernel. */
+    memset(&msg, 0, sizeof(msg));
+    memset(&cbuf, 0, sizeof(cbuf));
+
+    iov.iov_base = rec.contents.data;
+    iov.iov_len = rec.contents.length;
+    msg.msg_name = NULL;
+    msg.msg_namelen = 0;
+    msg.msg_iov = &iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cbuf.buf;
+    msg.msg_controllen = sizeof(cbuf.buf);
+
+    cmsg = CMSG_FIRSTHDR(&msg);
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_type = SCM_TLS_HEADER;
+    cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
+    hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
+    hdr->content_type = rec.contentType;
+    hdr->protocol_version = rec.protocolVersion;
+
+    sz = sendmsg(fd, &msg, 0);
+    if (sz < 0) {
+        int err = errno;            /* save before perror() can disturb it */
+        if (err == EAGAIN)
+            return errSSLRecordWouldBlock;
+        perror("sendmsg");
+        return err;
+    }
+
+    /* A short send is reported in debug builds but, as before, not treated
+       as a failure. */
+    check((size_t)sz == rec.contents.length);
+    return 0;
+}
+
+
+/* Zero n bytes at p via volatile stores, which the compiler may not elide
+   (a plain memset() right before free() can be optimised away). */
+static void TLSSocket_Wipe(void *p, size_t n)
+{
+    volatile uint8_t *v = p;
+    while (n--)
+        *v++ = 0;
+}
+
+/*
+ * SSLRecordFuncs.initPendingCiphers callback: pass the selected cipher suite,
+ * role and key material to the NKE via SO_TLS_INIT_CIPHER.
+ *
+ * Option payload layout:
+ *   - cipher suite: 2 bytes, big-endian
+ *   - server flag: 1 byte (0/1)
+ *   - key material: key.length bytes
+ *
+ * Returns setsockopt()'s result (0 or -1). Returns EINVAL for a key that
+ * cannot be encoded, and ENOMEM if the option buffer cannot be allocated.
+ */
+static
+int TLSSocket_InitPendingCiphers(SSLRecordContextRef   ref,
+                                       uint16_t              selectedCipher,
+                                       bool                  server,
+                                       SSLBuffer             key)
+{
+    int fd = (int)(intptr_t)ref;
+    int rc;
+    uint8_t *buf;
+    size_t buf_len;
+
+    /* The option length is a 32-bit socklen_t; reject keys that would
+       overflow it, and non-empty keys with no data. */
+    if (key.length > (size_t)UINT32_MAX - 3)
+        return EINVAL;
+    if (key.length != 0 && key.data == NULL)
+        return EINVAL;
+
+    buf_len = key.length + 3;
+    buf = malloc(buf_len);
+    if (buf == NULL)
+        return ENOMEM;
+
+    buf[0] = (uint8_t)(selectedCipher >> 8);
+    buf[1] = (uint8_t)(selectedCipher & 0xff);
+    buf[2] = server ? 1 : 0;
+    if (key.length != 0)
+        memcpy(buf + 3, key.data, key.length);
+
+    printf("%s: cipher=%04x, keylen=%zu\n", __FUNCTION__, selectedCipher, key.length);
+
+    rc = setsockopt(fd, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)buf_len);
+
+    printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+    /* Do not leave key material behind in freed heap memory. */
+    TLSSocket_Wipe(buf, buf_len);
+    free(buf);
+
+    return rc;
+}
+
+/*
+ * SSLRecordFuncs.advanceWriteCipher callback: issue the payload-less
+ * SO_TLS_ADVANCE_WRITE_CIPHER option on the socket in `ref`.
+ * Logs and returns setsockopt()'s result.
+ */
+static 
+int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
+{
+    const int fd = (int)ref;
+    const int result = setsockopt(fd, SOL_SOCKET,
+                                  SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
+
+    printf("%s: rc=%d\n", __FUNCTION__, result);
+    return result;
+}
+
+/*
+ * SSLRecordFuncs.rollbackWriteCipher callback: issue the payload-less
+ * SO_TLS_ROLLBACK_WRITE_CIPHER option on the socket in `ref`.
+ * Logs and returns setsockopt()'s result.
+ */
+static 
+int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
+{
+    const int fd = (int)ref;
+    const int result = setsockopt(fd, SOL_SOCKET,
+                                  SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
+
+    printf("%s: rc=%d\n", __FUNCTION__, result);
+    return result;
+}
+
+/*
+ * SSLRecordFuncs.advanceReadCipher callback: issue the payload-less
+ * SO_TLS_ADVANCE_READ_CIPHER option on the socket in `ref`.
+ * Logs and returns setsockopt()'s result.
+ */
+static 
+int TLSSocket_AdvanceReadCipher(SSLRecordContextRef    ref)
+{
+    const int fd = (int)ref;
+    const int result = setsockopt(fd, SOL_SOCKET,
+                                  SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
+
+    printf("%s: rc=%d\n", __FUNCTION__, result);
+    return result;
+}
+
+/*
+ * SSLRecordFuncs.setProtocolVersion callback: pass the raw SSLProtocolVersion
+ * value to the filter through SO_TLS_PROTOCOL_VERSION.
+ * Logs and returns setsockopt()'s result.
+ */
+static 
+int TLSSocket_SetProtocolVersion(SSLRecordContextRef    ref,
+                                 SSLProtocolVersion     protocolVersion)
+{
+    const int fd = (int)ref;
+    const int result = setsockopt(fd, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION,
+                                  &protocolVersion, sizeof protocolVersion);
+
+    printf("%s: rc=%d\n", __FUNCTION__, result);
+    return result;
+}
+
+
+/*
+ * SSLRecordFuncs.serviceWriteQueue callback: issue the payload-less
+ * SO_TLS_SERVICE_WRITE_QUEUE option and return setsockopt()'s result
+ * (no logging).
+ */
+static
+int TLSSocket_ServiceWriteQueue(SSLRecordContextRef    ref)
+{
+    return setsockopt((int)ref, SOL_SOCKET,
+                      SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
+}
+
+
+/*
+ * SSLRecordFuncs.setOption callback: record-layer options are not supported
+ * by this socket-backed implementation (not needed for DTLS), so every
+ * request is rejected with EINVAL.
+ */
+static
+int TLSSocket_SetOption(SSLRecordContextRef    ref,
+                        SSLRecordOption        option,
+                        bool                   value)
+{
+    (void)ref;
+    (void)option;
+    (void)value;
+    return EINVAL;
+}
+
+/*
+ * Record-layer callback table (struct SSLRecordFuncs) implemented on top of
+ * the TLS NKE socket options above. Designated initialisers, so the order
+ * below is for readability only.
+ */
+const struct SSLRecordFuncs TLSSocket_Funcs = {
+    /* record I/O */
+    .read                = TLSSocket_Read,
+    .write               = TLSSocket_Write,
+    .free                = TLSSocket_Free,
+    .serviceWriteQueue   = TLSSocket_ServiceWriteQueue,
+    /* cipher / protocol state */
+    .initPendingCiphers  = TLSSocket_InitPendingCiphers,
+    .advanceWriteCipher  = TLSSocket_AdvanceWriteCipher,
+    .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
+    .advanceReadCipher   = TLSSocket_AdvanceReadCipher,
+    .setProtocolVersion  = TLSSocket_SetProtocolVersion,
+    .setOption           = TLSSocket_SetOption,
+};
+
+
+/* TLSSocket SPIs */
+
+/*
+ * Attach the TLS socket filter (TLS_HANDLE_IP4) to `socket` via SO_NKE,
+ * then query the filter's handle with SO_TLS_HANDLE.
+ *
+ * Returns the handle on success, or -1 with errno set on failure.
+ * NOTE(review): -1 is in-band with the handle value; presumably real handles
+ * are never -1 — confirm against tlsnke.
+ */
+int TLSSocket_Attach(int socket)
+{
+    struct so_nke so_tlsnke;
+    int rc;
+    int handle = 0;
+    socklen_t len;
+
+    memset(&so_tlsnke, 0, sizeof(so_tlsnke));
+    so_tlsnke.nke_handle = TLS_HANDLE_IP4;
+    rc = setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
+    if (rc)
+        return rc;
+
+    len = sizeof(handle);
+    rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
+    if (rc)
+        return rc;
+
+    /* Checked at runtime: assert() vanishes under NDEBUG and would let a
+       partially filled handle through. */
+    if (len != sizeof(handle)) {
+        errno = EINVAL;
+        return -1;
+    }
+
+    return handle;
+}
+