5 // Created by Fabrice Gautier on 1/6/12.
6 // Copyright (c) 2012 Apple, Inc. All rights reserved.
9 #include <Security/SecureTransportPriv.h>
11 #include <netinet/in.h>
12 #include <arpa/inet.h>
18 #include <net/kext_net.h>
20 #include "tlssocket.h"
23 #include <AssertMacros.h>
26 /* TLSSocket functions */
/*
 * TLSSocket_Read: receive one TLS record from the socket whose descriptor is
 * carried in the opaque `ref` handle, filling the caller's record with the
 * payload buffer, content type and protocol version.
 *
 * NOTE(review): this copy is fragmentary - source-viewer line numbers are
 * fused into the statements and several original lines (remaining
 * parameters, local declarations, closing brace, final return) are missing.
 * The comments below describe only the visible statements; verify against
 * the original file.
 */
29 int TLSSocket_Read(SSLRecordContextRef ref
,
// The socket descriptor travels inside the opaque record-layer handle.
// NOTE(review): if SSLRecordContextRef is a pointer type this cast narrows
// it to int - confirm the handle is always created from an fd.
32 int socket
= (int)ref
;
// Sender address, filled in by recvmsg() below and only used by the trace.
35 struct sockaddr_in client_addr
;
// Control-message buffer: room for one record-header cmsg plus 1024 spare bytes.
42 int cbuf_len
=CMSG_SPACE(sizeof(*hdr
))+1024;
43 uint8_t cbuf
[cbuf_len
];
46 // printf("%s: Waiting for some data...\n", __FUNCTION__);
// Peek at a single byte (MSG_PEEK leaves it queued) to see whether data is pending.
49 rc
= (int)recv(socket
, &b
, 1, MSG_PEEK
);
// The condition guarding this early return is not visible in this copy;
// presumably it is taken when the peek finds nothing to read - TODO confirm.
54 return errSSLRecordWouldBlock
;
61 /* get the next packet size */
// SO_NREAD reports how many bytes are readable; that count sizes the buffer.
62 avail_size
= sizeof(avail
);
63 rc
= getsockopt(socket
, SOL_SOCKET
, SO_NREAD
, &avail
, &avail_size
);
// check() (AssertMacros.h) only reports a failure; the if below is the real guard.
66 check(avail_size
==sizeof(avail
));
68 if(rc
|| (avail_size
!=sizeof(avail
)))
69 return errSSLRecordInternal
;
71 // printf("%s: Available = %d\n", __FUNCTION__, avail);
// The guard for this return is not visible; presumably no bytes are available yet.
74 return errSSLRecordWouldBlock
;
77 /* Allocate a buffer */
// NOTE(review): the malloc() result is not checked for NULL in the visible lines.
// The buffer is stored in the caller's record; TLSSocket_Free frees
// contents.data, presumably as the matching release - confirm.
78 rec
->contents
.data
= malloc(avail
);
79 rec
->contents
.length
= avail
;
81 /* read the message */
// A single recvmsg() call collects the payload, the sender address and the
// ancillary (control) data.
82 iov
.iov_base
= rec
->contents
.data
;
83 iov
.iov_len
= rec
->contents
.length
;
84 msg
.msg_name
= &client_addr
;
85 msg
.msg_namelen
= sizeof(client_addr
);
88 msg
.msg_control
= cbuf
;
89 msg
.msg_controllen
= cbuf_len
;
91 sz
= recvmsg(socket
, &msg
, 0);
94 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
// NOTE(review): no check for a recvmsg() failure (-1) is visible before sz
// is stored as the record length.
95 rec
->contents
.length
= sz
;
// The record header arrives as an SCM_TLS_HEADER control message, not in the payload.
// NOTE(review): CMSG_FIRSTHDR() yields NULL when too little control data was
// received; no NULL check is visible before cmsg is dereferenced below.
97 cmsg
= CMSG_FIRSTHDR(&msg
);
// Non-fatal sanity checks on the control message (check() does not branch).
102 check(cmsg
->cmsg_type
== SCM_TLS_HEADER
);
103 check(cmsg
->cmsg_level
== SOL_SOCKET
);
104 check(cmsg
->cmsg_len
== CMSG_LEN(sizeof(*hdr
)));
105 hdr
= (tls_record_hdr_t
)CMSG_DATA(cmsg
);
// Debug trace of the received record (one argument line is missing from this copy).
110 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
112 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
113 hdr->content_type, hdr->protocol_version,
114 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
// Copy the record type and protocol version out of the control-message header.
116 rec
->contentType
= hdr
->content_type
;
117 rec
->protocolVersion
= hdr
->protocol_version
;
119 if(rec
->contentType
==SSL_RecordTypeChangeCipher
) {
120 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__
);
/*
 * TLSSocket_Free: release the payload buffer of a record.
 * The buffer freed here (rec.contents.data) is the field TLSSocket_Read
 * fills with malloc(); presumably this is its matching release - confirm.
 * NOTE(review): the remaining parameter and the function braces are not
 * present in this copy. `rec` is accessed with '.', so it looks like it is
 * taken by value - if so, the caller's pointer is not cleared after free.
 */
126 int TLSSocket_Free(SSLRecordContextRef ref
,
129 free(rec
.contents
.data
);
/*
 * TLSSocket_Write: send one TLS record over the socket carried in `ref`.
 * The payload goes in the data part of a sendmsg() call. The record header
 * (content type, protocol version) travels with it as an SCM_TLS_HEADER
 * control message.
 * NOTE(review): this copy is fragmentary (remaining parameters, some
 * declarations, the closing brace and the return are missing).
 */
134 int TLSSocket_Write(SSLRecordContextRef ref
,
// NOTE(review): if SSLRecordContextRef is a pointer type this cast narrows
// it to int - confirm the handle is always created from an fd.
137 int socket
= (int)ref
;
// tls_record_hdr_t is used as a pointer type (it is dereferenced below).
142 tls_record_hdr_t hdr
;
143 struct cmsghdr
*cmsg
;
// Control buffer sized for exactly one record-header control message.
144 int cbuf_len
=CMSG_SPACE(sizeof(*hdr
));
145 uint8_t cbuf
[cbuf_len
];
147 if(rec
.contentType
==SSL_RecordTypeChangeCipher
) {
148 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__
);
150 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
152 /* write the message */
153 iov
.iov_base
= rec
.contents
.data
;
154 iov
.iov_len
= rec
.contents
.length
;
// Attach the control buffer; the header control message is built in it below.
159 msg
.msg_control
= cbuf
;
160 msg
.msg_controllen
= cbuf_len
;
// Build the SCM_TLS_HEADER control message in place and fill in the header.
162 cmsg
= CMSG_FIRSTHDR(&msg
);
163 cmsg
->cmsg_level
= SOL_SOCKET
;
164 cmsg
->cmsg_type
= SCM_TLS_HEADER
;
165 cmsg
->cmsg_len
= CMSG_LEN(sizeof(*hdr
));
166 hdr
= (tls_record_hdr_t
)CMSG_DATA(cmsg
);
167 hdr
->content_type
= rec
.contentType
;
168 hdr
->protocol_version
= rec
.protocolVersion
;
171 sz
= sendmsg(socket
, &msg
, 0);
177 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
179 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
// NOTE(review): check() (AssertMacros.h) only reports; a failed (-1) or short
// send is not handled in the visible lines.
182 check(sz
==rec
.contents
.length
);
/*
 * TLSSocket_InitPendingCiphers: pass the selected cipher suite and key
 * material to the socket with setsockopt(SO_TLS_INIT_CIPHER).
 * Option buffer layout built below:
 *   byte 0    - high byte of selectedCipher
 *   byte 1    - low byte of selectedCipher
 *   byte 2    - not written in the visible lines (an original line appears
 *               to be missing here) - TODO confirm what it carries
 *   bytes 3.. - key.length bytes copied from key.data
 * NOTE(review): remaining parameters, declarations and the tail of the
 * function are not present in this copy.
 */
192 int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref
,
193 uint16_t selectedCipher
,
197 int socket
= (int)ref
;
// NOTE(review): the malloc() result is not checked for NULL and no free(buf)
// is visible. buf holds key material, so it should also be wiped before
// release - confirm against the full source.
201 buf
= malloc(key
.length
+3);
// Cipher suite is stored big-endian (high byte first).
202 buf
[0] = selectedCipher
>> 8;
203 buf
[1] = selectedCipher
& 0xff;
205 memcpy(buf
+3, key
.data
, key
.length
);
207 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__
, selectedCipher
, key
.length
);
209 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_INIT_CIPHER
, buf
, (socklen_t
)(key
.length
+3));
211 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_AdvanceWriteCipher: issue setsockopt(SO_TLS_ADVANCE_WRITE_CIPHER)
 * with no option value on the socket carried in `ref`, and trace the result.
 * Presumably this switches writes to the cipher installed by
 * TLSSocket_InitPendingCiphers - confirm against the filter implementation.
 * NOTE(review): the function braces and return are not present in this copy.
 */
219 int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref
)
221 int socket
= (int)ref
;
223 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ADVANCE_WRITE_CIPHER
, NULL
, 0);
225 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_RollbackWriteCipher: issue
 * setsockopt(SO_TLS_ROLLBACK_WRITE_CIPHER) with no option value on the
 * socket carried in `ref`, and trace the result. Presumably this undoes a
 * write-cipher advance - confirm against the filter implementation.
 * NOTE(review): the function braces and return are not present in this copy.
 */
231 int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref
)
233 int socket
= (int)ref
;
235 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ROLLBACK_WRITE_CIPHER
, NULL
, 0);
237 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_AdvanceReadCipher: issue setsockopt(SO_TLS_ADVANCE_READ_CIPHER)
 * with no option value on the socket carried in `ref`, and trace the result.
 * Presumably this switches reads to the pending cipher - confirm.
 * NOTE(review): the function braces and return are not present in this copy.
 */
243 int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref
)
245 int socket
= (int)ref
;
247 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_ADVANCE_READ_CIPHER
, NULL
, 0);
249 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_SetProtocolVersion: pass the negotiated SSLProtocolVersion
 * value, as raw bytes of its declared size, to the socket with
 * setsockopt(SO_TLS_PROTOCOL_VERSION), and trace the result.
 * NOTE(review): the function braces and return are not present in this copy.
 */
255 int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref
,
256 SSLProtocolVersion protocolVersion
)
258 int socket
= (int)ref
;
260 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_PROTOCOL_VERSION
, &protocolVersion
, sizeof(protocolVersion
));
262 printf("%s: rc=%d\n", __FUNCTION__
, rc
);
/*
 * TLSSocket_ServiceWriteQueue: issue
 * setsockopt(SO_TLS_SERVICE_WRITE_QUEUE) with no option value on the socket
 * carried in `ref`. Presumably this asks the filter to flush queued
 * outgoing records - confirm against the filter implementation.
 * NOTE(review): the function braces and return are not present in this copy.
 */
269 int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref
)
271 int socket
= (int)ref
;
273 rc
= setsockopt(socket
, SOL_SOCKET
, SO_TLS_SERVICE_WRITE_QUEUE
, NULL
, 0);
/*
 * TLSSocket_Funcs: record-layer dispatch table that plugs the TLSSocket_*
 * functions in this file into struct SSLRecordFuncs.
 * NOTE(review): the closing "};" of this initializer is not present in
 * this copy.
 */
279 const struct SSLRecordFuncs TLSSocket_Funcs
= {
280 .read
= TLSSocket_Read
,
281 .write
= TLSSocket_Write
,
282 .initPendingCiphers
= TLSSocket_InitPendingCiphers
,
283 .advanceWriteCipher
= TLSSocket_AdvanceWriteCipher
,
284 .rollbackWriteCipher
= TLSSocket_RollbackWriteCipher
,
285 .advanceReadCipher
= TLSSocket_AdvanceReadCipher
,
286 .setProtocolVersion
= TLSSocket_SetProtocolVersion
,
287 .free
= TLSSocket_Free
,
288 .serviceWriteQueue
= TLSSocket_ServiceWriteQueue
,
294 int TLSSocket_Attach(int socket
)
297 /* Attach the TLS socket filter and return handle */
298 struct so_nke so_tlsnke
;
303 memset(&so_tlsnke
, 0, sizeof(so_tlsnke
));
304 so_tlsnke
.nke_handle
= TLS_HANDLE_IP4
;
305 rc
=setsockopt(socket
, SOL_SOCKET
, SO_NKE
, &so_tlsnke
, sizeof(so_tlsnke
));
309 len
= sizeof(handle
);
310 rc
= getsockopt(socket
, SOL_SOCKET
, SO_TLS_HANDLE
, &handle
, &len
);
314 assert(len
==sizeof(handle
));