--- /dev/null
+/*
+ * Copyright (c) 2012-2014 Apple Inc. All Rights Reserved.
+ *
+ * @APPLE_LICENSE_HEADER_START@
+ *
+ * This file contains Original Code and/or Modifications of Original Code
+ * as defined in and that are subject to the Apple Public Source License
+ * Version 2.0 (the 'License'). You may not use this file except in
+ * compliance with the License. Please obtain a copy of the License at
+ * http://www.opensource.apple.com/apsl/ and read it before using this
+ * file.
+ *
+ * The Original Code and all software distributed under the License are
+ * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
+ * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
+ * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
+ * Please see the License for the specific language governing rights and
+ * limitations under the License.
+ *
+ * @APPLE_LICENSE_HEADER_END@
+ */
+
+
+#include <Security/SecureTransportPriv.h>
+#include <string.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+
+#include <stdlib.h>
+#include <stdio.h>
+#include <assert.h>
+
+#include <net/kext_net.h>
+
+#include "tlssocket.h"
+#include "tlsnke.h"
+
+#include <AssertMacros.h>
+#include <errno.h>
+
+/* TLSSocket functions */
+
+static
+int TLSSocket_Read(SSLRecordContextRef ref,
+ SSLRecord *rec)
+{
+ int socket = (int)ref;
+ int rc;
+ ssize_t sz;
+ struct sockaddr_in client_addr;
+ int avail;
+ socklen_t avail_size;
+ struct cmsghdr *cmsg;
+ tls_record_hdr_t hdr;
+ struct msghdr msg;
+ struct iovec iov;
+ int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
+ uint8_t cbuf[cbuf_len];
+
+
+ // printf("%s: Waiting for some data...\n", __FUNCTION__);
+ /* PEEK only... */
+ char b;
+ rc = (int)recv(socket, &b, 1, MSG_PEEK);
+
+ if(rc==-1)
+ {
+ if(errno==EAGAIN)
+ return errSSLRecordWouldBlock;
+ else {
+ perror("recv");
+ return errno;
+ }
+ }
+
+ /* get the next packet size */
+ avail_size = sizeof(avail);
+ rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
+
+ check_noerr(rc);
+ check(avail_size==sizeof(avail));
+
+ if(rc || (avail_size !=sizeof(avail)))
+ return errSSLRecordInternal;
+
+ // printf("%s: Available = %d\n", __FUNCTION__, avail);
+
+ if(avail==0)
+ return errSSLRecordWouldBlock;
+
+
+ /* Allocate a buffer */
+ rec->contents.data = malloc(avail);
+ rec->contents.length = avail;
+
+ /* read the message */
+ iov.iov_base = rec->contents.data;
+ iov.iov_len = rec->contents.length;
+ msg.msg_name = &client_addr;
+ msg.msg_namelen = sizeof(client_addr);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cbuf;
+ msg.msg_controllen = cbuf_len;
+
+ sz = recvmsg(socket, &msg, 0);
+ check(sz==avail);
+
+ // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
+ rec->contents.length = sz;
+
+ cmsg = CMSG_FIRSTHDR(&msg);
+ check(cmsg);
+ if(!cmsg)
+ return 0;
+
+ check(cmsg->cmsg_type == SCM_TLS_HEADER);
+ check(cmsg->cmsg_level == SOL_SOCKET);
+ check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
+ hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
+ check(hdr);
+
+ /* print msg info */
+ /*
+ printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
+ iov.iov_len,
+ cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
+ hdr->content_type, hdr->protocol_version,
+ inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
+ */
+ rec->contentType = hdr->content_type;
+ rec->protocolVersion = hdr->protocol_version;
+
+ if(rec->contentType==SSL_RecordTypeChangeCipher) {
+ printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
+ }
+ return 0;
+}
+
+static
+int TLSSocket_Free(SSLRecordContextRef ref,
+ SSLRecord rec)
+{
+ free(rec.contents.data);
+ return 0;
+}
+
+static
+int TLSSocket_Write(SSLRecordContextRef ref,
+ SSLRecord rec)
+{
+ int socket = (int)ref;
+ ssize_t sz;
+
+ struct msghdr msg;
+ struct iovec iov;
+ tls_record_hdr_t hdr;
+ struct cmsghdr *cmsg;
+ int cbuf_len=CMSG_SPACE(sizeof(*hdr));
+ uint8_t cbuf[cbuf_len];
+
+ if(rec.contentType==SSL_RecordTypeChangeCipher) {
+ printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
+ }
+ // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
+
+ /* write the message */
+ iov.iov_base = rec.contents.data;
+ iov.iov_len = rec.contents.length;
+ msg.msg_name = NULL;
+ msg.msg_namelen = 0;
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = cbuf;
+ msg.msg_controllen = cbuf_len;
+
+ cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_TLS_HEADER;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
+ hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
+ hdr->content_type = rec.contentType;
+ hdr->protocol_version = rec.protocolVersion;
+
+ /* print msg info */
+ sz = sendmsg(socket, &msg, 0);
+
+ if(sz<0)
+ perror("sendmsg");
+
+ /*
+ printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
+ iov.iov_len,
+ cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
+ */
+
+ check(sz==rec.contents.length);
+
+ if(sz<0)
+ return (int)sz;
+ else
+ return 0;
+}
+
+
+static
+int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref,
+ uint16_t selectedCipher,
+ bool server,
+ SSLBuffer key)
+{
+ int socket = (int)ref;
+ int rc;
+ char *buf;
+
+ buf = malloc(key.length+3);
+ buf[0] = selectedCipher >> 8;
+ buf[1] = selectedCipher & 0xff;
+ buf[2] = server;
+ memcpy(buf+3, key.data, key.length);
+
+ printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);
+
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));
+
+ printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+ free(buf);
+
+ return rc;
+}
+
+static
+int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
+{
+ int socket = (int)ref;
+ int rc;
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
+
+ printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+ return rc;
+}
+
+static
+int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
+{
+ int socket = (int)ref;
+ int rc;
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
+
+ printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+ return rc;
+}
+
+static
+int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref)
+{
+ int socket = (int)ref;
+ int rc;
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
+
+ printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+ return rc;
+}
+
+static
+int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref,
+ SSLProtocolVersion protocolVersion)
+{
+ int socket = (int)ref;
+ int rc;
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion));
+
+ printf("%s: rc=%d\n", __FUNCTION__, rc);
+
+ return rc;
+}
+
+
+static
+int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref)
+{
+ int socket = (int)ref;
+ int rc;
+ rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
+
+ return rc;
+}
+
+
+static
+int TLSSocket_SetOption(SSLRecordContextRef ref,
+ SSLRecordOption option,
+ bool value)
+{
+ /* This is not implemented, and is not needed for DTLS */
+ return EINVAL;
+}
+
+const struct SSLRecordFuncs TLSSocket_Funcs = {
+ .read = TLSSocket_Read,
+ .write = TLSSocket_Write,
+ .initPendingCiphers = TLSSocket_InitPendingCiphers,
+ .advanceWriteCipher = TLSSocket_AdvanceWriteCipher,
+ .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
+ .advanceReadCipher = TLSSocket_AdvanceReadCipher,
+ .setProtocolVersion = TLSSocket_SetProtocolVersion,
+ .free = TLSSocket_Free,
+ .serviceWriteQueue = TLSSocket_ServiceWriteQueue,
+ .setOption = TLSSocket_SetOption,
+};
+
+
+/* TLSSocket SPIs */
+
+int TLSSocket_Attach(int socket)
+{
+
+ /* Attach the TLS socket filter and return handle */
+ struct so_nke so_tlsnke;
+ int rc;
+ int handle;
+ socklen_t len;
+
+ memset(&so_tlsnke, 0, sizeof(so_tlsnke));
+ so_tlsnke.nke_handle = TLS_HANDLE_IP4;
+ rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
+ if(rc)
+ return rc;
+
+ len = sizeof(handle);
+ rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
+ if(rc)
+ return rc;
+
+ assert(len==sizeof(handle));
+
+ return handle;
+}
+