2 * Copyright (c) 2012-2014 Apple Inc. All Rights Reserved.
4 * @APPLE_LICENSE_HEADER_START@
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
21 * @APPLE_LICENSE_HEADER_END@
25 #include <Security/SecureTransportPriv.h>
27 #include <netinet/in.h>
28 #include <arpa/inet.h>
34 #include <net/kext_net.h>
36 #include "tlssocket.h"
39 #include <AssertMacros.h>
42 /* TLSSocket functions */
/*
 * TLSSocket_Read: the .read callback in TLSSocket_Funcs. Receives one TLS
 * record from the socket whose descriptor is carried in `ref` (the record
 * context is cast straight back to an int file descriptor).
 *
 * Visible flow:
 *  1. peek one byte with MSG_PEEK; errSSLRecordWouldBlock may be returned;
 *  2. query the next packet size with getsockopt(SO_NREAD); fail with
 *     errSSLRecordInternal if that call fails or reports an unexpected size;
 *  3. malloc rec->contents.data of that size and fill it with recvmsg(),
 *     also capturing the sender address and ancillary (control) data;
 *  4. copy the content type and protocol version out of the SCM_TLS_HEADER
 *     control message into *rec.
 *
 * The payload buffer is heap-allocated here. TLSSocket_Free frees
 * rec.contents.data and is presumably the matching release call.
 *
 * Note: check() comes from AssertMacros.h and only reports a failed
 * condition; it does not alter control flow.
 */
45 int TLSSocket_Read(SSLRecordContextRef ref
,
48 int socket
= (int)ref
;
51 struct sockaddr_in client_addr
;
/* Control buffer: room for one TLS record header message plus 1024 extra bytes. */
58 int cbuf_len
=CMSG_SPACE(sizeof(*hdr
))+1024;
59 uint8_t cbuf
[cbuf_len
];
62 // printf("%s: Waiting for some data...\n", __FUNCTION__);
/* Non-destructive probe: MSG_PEEK leaves the byte in the receive queue. */
65 rc
= (int)recv(socket
, &b
, 1, MSG_PEEK
);
/* NOTE(review): presumably returned when the peek finds nothing to read;
 * confirm against the guarding condition. */
70 return errSSLRecordWouldBlock
;
77 /* get the next packet size */
78 avail_size
= sizeof(avail
);
79 rc
= getsockopt(socket
, SOL_SOCKET
, SO_NREAD
, &avail
, &avail_size
);
82 check(avail_size
==sizeof(avail
));
/* check() above only reports; this test is what actually rejects the failure. */
84 if(rc
|| (avail_size
!=sizeof(avail
)))
85 return errSSLRecordInternal
;
87 // printf("%s: Available = %d\n", __FUNCTION__, avail);
/* NOTE(review): presumably returned when SO_NREAD reports nothing available;
 * confirm against the guarding condition. */
90 return errSSLRecordWouldBlock
;
93 /* Allocate a buffer */
/* NOTE(review): the malloc() result is not checked before it is used as the
 * recvmsg() buffer. Confirm there is a NULL check, or add one. */
94 rec
->contents
.data
= malloc(avail
);
95 rec
->contents
.length
= avail
;
97 /* read the message */
98 iov
.iov_base
= rec
->contents
.data
;
99 iov
.iov_len
= rec
->contents
.length
;
/* Capture the sender's IPv4 address (used only by the trace printf below). */
100 msg
.msg_name
= &client_addr
;
101 msg
.msg_namelen
= sizeof(client_addr
);
104 msg
.msg_control
= cbuf
;
105 msg
.msg_controllen
= cbuf_len
;
107 sz
= recvmsg(socket
, &msg
, 0);
110 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
/* Record the number of bytes recvmsg() actually delivered.
 * NOTE(review): there is no check here for recvmsg() returning -1. If
 * contents.length is unsigned, a failure would become a huge length. Confirm. */
111 rec
->contents
.length
= sz
;
/* NOTE(review): CMSG_FIRSTHDR() returns NULL when no control data arrived,
 * yet cmsg is dereferenced by the check() calls below without a NULL test. */
113 cmsg
= CMSG_FIRSTHDR(&msg
);
118 check(cmsg
->cmsg_type
== SCM_TLS_HEADER
);
119 check(cmsg
->cmsg_level
== SOL_SOCKET
);
120 check(cmsg
->cmsg_len
== CMSG_LEN(sizeof(*hdr
)));
/* tls_record_hdr_t is used as a pointer type; hdr points into cbuf. */
121 hdr
= (tls_record_hdr_t
)CMSG_DATA(cmsg
);
/* Diagnostic trace of the received record and its sender. */
126 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
128 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
129 hdr->content_type, hdr->protocol_version,
130 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
/* Hand the record metadata from the control-message header to the caller. */
132 rec
->contentType
= hdr
->content_type
;
133 rec
->protocolVersion
= hdr
->protocol_version
;
135 if(rec
->contentType
==SSL_RecordTypeChangeCipher
) {
136 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__
);
/*
 * TLSSocket_Free: the .free callback in TLSSocket_Funcs. Releases the record
 * payload buffer rec.contents.data, which TLSSocket_Read allocates with
 * malloc(). rec is accessed with `.` here, so it appears to be passed by
 * value; the caller's copy of the pointer is therefore not cleared.
 */
142 int TLSSocket_Free(SSLRecordContextRef ref
,
145 free(rec
.contents
.data
);
/*
 * TLSSocket_Write: the .write callback in TLSSocket_Funcs. Sends one record
 * on the socket stored in `ref` with a single sendmsg() call:
 *  - the payload (rec.contents) goes in the iovec;
 *  - the record's content type and protocol version travel in one
 *    SOL_SOCKET / SCM_TLS_HEADER control message.
 */
150 int TLSSocket_Write(SSLRecordContextRef ref
,
153 int socket
= (int)ref
;
/* tls_record_hdr_t is used as a pointer type (see *hdr and hdr-> below). */
158 tls_record_hdr_t hdr
;
159 struct cmsghdr
*cmsg
;
/* Sized for exactly one control message carrying the record header. */
160 int cbuf_len
=CMSG_SPACE(sizeof(*hdr
));
161 uint8_t cbuf
[cbuf_len
];
163 if(rec
.contentType
==SSL_RecordTypeChangeCipher
) {
164 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__
);
166 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
168 /* write the message */
169 iov
.iov_base
= rec
.contents
.data
;
170 iov
.iov_len
= rec
.contents
.length
;
/* NOTE(review): the remaining msg fields (msg_name, msg_iov, msg_flags) are
 * not set here. Confirm msg is zero-initialized or fully assigned. */
175 msg
.msg_control
= cbuf
;
176 msg
.msg_controllen
= cbuf_len
;
/* Build the single SCM_TLS_HEADER control message in cbuf. */
178 cmsg
= CMSG_FIRSTHDR(&msg
);
179 cmsg
->cmsg_level
= SOL_SOCKET
;
180 cmsg
->cmsg_type
= SCM_TLS_HEADER
;
181 cmsg
->cmsg_len
= CMSG_LEN(sizeof(*hdr
));
182 hdr
= (tls_record_hdr_t
)CMSG_DATA(cmsg
);
183 hdr
->content_type
= rec
.contentType
;
184 hdr
->protocol_version
= rec
.protocolVersion
;
187 sz
= sendmsg(socket
, &msg
, 0);
193 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
195 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
/* NOTE(review): check() (AssertMacros.h) only reports a mismatch. A short
 * write, or a sendmsg() failure returning -1, is not otherwise handled here. */
198 check(sz
==rec
.contents
.length
);
/*
 * TLSSocket_InitPendingCiphers: the .initPendingCiphers callback in
 * TLSSocket_Funcs. Packs the selected cipher suite and the key material into
 * a heap buffer and passes it to the socket with setsockopt(SO_TLS_INIT_CIPHER).
 *
 * Buffer layout (key.length + 3 bytes):
 *   [0]    high byte of selectedCipher
 *   [1]    low byte of selectedCipher
 *   [2]    NOTE(review): not assigned in the code as it appears here; confirm
 *   [3..]  key.data (key.length bytes)
 */
208 int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref
,
209 uint16_t selectedCipher
,
213 int socket
= (int)ref
;
/* NOTE(review): the malloc() result is not checked before the writes below.
 * Also confirm buf is freed after setsockopt(). */
217 buf
= malloc(key
.length
+3);
/* Cipher suite is stored big-endian (most significant byte first). */
218 buf
[0] = selectedCipher
>> 8;
219 buf
[1] = selectedCipher
& 0xff;
221 memcpy(buf
+3, key
.data
, key
.length
);
/* NOTE(review): if key.length is size_t, %zu is the matching specifier. */
223 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__
, selectedCipher
, key
.length
);
225 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_INIT_CIPHER
, buf
, (socklen_t
)(key
.length
+3));
227 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_AdvanceWriteCipher: the .advanceWriteCipher callback in
 * TLSSocket_Funcs. Issues the argument-less setsockopt(SO_TLS_ADVANCE_WRITE_CIPHER)
 * on the socket stored in `ref` and logs the result. Judging by the option
 * name, this presumably activates the pending write cipher.
 */
235 int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref
)
237 int socket
= (int)ref
;
239 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ADVANCE_WRITE_CIPHER
, NULL
, 0);
241 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_RollbackWriteCipher: the .rollbackWriteCipher callback in
 * TLSSocket_Funcs. Issues the argument-less
 * setsockopt(SO_TLS_ROLLBACK_WRITE_CIPHER) on the socket stored in `ref` and
 * logs the result. Judging by the option name, this presumably undoes a
 * previous write-cipher advance.
 */
247 int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref
)
249 int socket
= (int)ref
;
251 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ROLLBACK_WRITE_CIPHER
, NULL
, 0);
253 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_AdvanceReadCipher: the .advanceReadCipher callback in
 * TLSSocket_Funcs. Issues the argument-less setsockopt(SO_TLS_ADVANCE_READ_CIPHER)
 * on the socket stored in `ref` and logs the result. Judging by the option
 * name, this presumably activates the pending read cipher.
 */
259 int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref
)
261 int socket
= (int)ref
;
263 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ADVANCE_READ_CIPHER
, NULL
, 0);
265 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_SetProtocolVersion: the .setProtocolVersion callback in
 * TLSSocket_Funcs. Passes the raw SSLProtocolVersion value (its local copy,
 * sizeof(protocolVersion) bytes) to the socket with
 * setsockopt(SO_TLS_PROTOCOL_VERSION) and logs the result.
 */
271 int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref
,
272 SSLProtocolVersion protocolVersion
)
274 int socket
= (int)ref
;
276 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_PROTOCOL_VERSION
, &protocolVersion
, sizeof(protocolVersion
));
278 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_ServiceWriteQueue: the .serviceWriteQueue callback in
 * TLSSocket_Funcs. Issues the argument-less
 * setsockopt(SO_TLS_SERVICE_WRITE_QUEUE) on the socket stored in `ref`.
 * Unlike its sibling callbacks, it does not log rc.
 */
285 int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref
)
287 int socket
= (int)ref
;
289 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_SERVICE_WRITE_QUEUE
, NULL
, 0);
/*
 * TLSSocket_SetOption: the .setOption callback in TLSSocket_Funcs.
 * It exists only to fill the dispatch-table slot; record options are
 * deliberately unsupported by this socket-backed record layer.
 */
296 int TLSSocket_SetOption(SSLRecordContextRef ref
,
297 SSLRecordOption option
,
300 /* This is not implemented, and is not needed for DTLS */
/*
 * Record-layer dispatch table: routes each SSLRecordFuncs callback to the
 * matching TLSSocket_* function in this file. Every callback recovers the
 * socket descriptor from its SSLRecordContextRef argument.
 */
304 const struct SSLRecordFuncs TLSSocket_Funcs
= {
305 .read
= TLSSocket_Read
,
306 .write
= TLSSocket_Write
,
307 .initPendingCiphers
= TLSSocket_InitPendingCiphers
,
308 .advanceWriteCipher
= TLSSocket_AdvanceWriteCipher
,
309 .rollbackWriteCipher
= TLSSocket_RollbackWriteCipher
,
310 .advanceReadCipher
= TLSSocket_AdvanceReadCipher
,
311 .setProtocolVersion
= TLSSocket_SetProtocolVersion
,
312 .free
= TLSSocket_Free
,
313 .serviceWriteQueue
= TLSSocket_ServiceWriteQueue
,
314 .setOption
= TLSSocket_SetOption
,
320 int TLSSocket_Attach(int socket
)
323 /* Attach the TLS socket filter and return handle */
324 struct so_nke so_tlsnke
;
329 memset(&so_tlsnke
, 0, sizeof(so_tlsnke
));
330 so_tlsnke
.nke_handle
= TLS_HANDLE_IP4
;
331 rc
=setsockopt(socket
, SOL_SOCKET
, SO_NKE
, &so_tlsnke
, sizeof(so_tlsnke
));
335 len
= sizeof(handle
);
336 rc
= getsockopt(socket
, SOL_SOCKET
, SO_TLS_HANDLE
, &handle
, &len
);
340 assert(len
==sizeof(handle
));