]> git.saurik.com Git - apple/security.git/blob - Security/tlsnke/tlsnketest/tlssocket.c
Security-57031.10.10.tar.gz
[apple/security.git] / Security / tlsnke / tlsnketest / tlssocket.c
1 /*
2 * Copyright (c) 2012-2014 Apple Inc. All Rights Reserved.
3 *
4 * @APPLE_LICENSE_HEADER_START@
5 *
6 * This file contains Original Code and/or Modifications of Original Code
7 * as defined in and that are subject to the Apple Public Source License
8 * Version 2.0 (the 'License'). You may not use this file except in
9 * compliance with the License. Please obtain a copy of the License at
10 * http://www.opensource.apple.com/apsl/ and read it before using this
11 * file.
12 *
13 * The Original Code and all software distributed under the License are
14 * distributed on an 'AS IS' basis, WITHOUT WARRANTY OF ANY KIND, EITHER
15 * EXPRESS OR IMPLIED, AND APPLE HEREBY DISCLAIMS ALL SUCH WARRANTIES,
16 * INCLUDING WITHOUT LIMITATION, ANY WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE, QUIET ENJOYMENT OR NON-INFRINGEMENT.
18 * Please see the License for the specific language governing rights and
19 * limitations under the License.
20 *
21 * @APPLE_LICENSE_HEADER_END@
22 */
23
24
#include <Security/SecureTransportPriv.h>
#include <string.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

#include <net/kext_net.h>

#include "tlssocket.h"
#include "tlsnke.h"

#include <AssertMacros.h>
#include <errno.h>
41
42 /* TLSSocket functions */
43
44 static
45 int TLSSocket_Read(SSLRecordContextRef ref,
46 SSLRecord *rec)
47 {
48 int socket = (int)ref;
49 int rc;
50 ssize_t sz;
51 struct sockaddr_in client_addr;
52 int avail;
53 socklen_t avail_size;
54 struct cmsghdr *cmsg;
55 tls_record_hdr_t hdr;
56 struct msghdr msg;
57 struct iovec iov;
58 int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
59 uint8_t cbuf[cbuf_len];
60
61
62 // printf("%s: Waiting for some data...\n", __FUNCTION__);
63 /* PEEK only... */
64 char b;
65 rc = (int)recv(socket, &b, 1, MSG_PEEK);
66
67 if(rc==-1)
68 {
69 if(errno==EAGAIN)
70 return errSSLRecordWouldBlock;
71 else {
72 perror("recv");
73 return errno;
74 }
75 }
76
77 /* get the next packet size */
78 avail_size = sizeof(avail);
79 rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
80
81 check_noerr(rc);
82 check(avail_size==sizeof(avail));
83
84 if(rc || (avail_size !=sizeof(avail)))
85 return errSSLRecordInternal;
86
87 // printf("%s: Available = %d\n", __FUNCTION__, avail);
88
89 if(avail==0)
90 return errSSLRecordWouldBlock;
91
92
93 /* Allocate a buffer */
94 rec->contents.data = malloc(avail);
95 rec->contents.length = avail;
96
97 /* read the message */
98 iov.iov_base = rec->contents.data;
99 iov.iov_len = rec->contents.length;
100 msg.msg_name = &client_addr;
101 msg.msg_namelen = sizeof(client_addr);
102 msg.msg_iov = &iov;
103 msg.msg_iovlen = 1;
104 msg.msg_control = cbuf;
105 msg.msg_controllen = cbuf_len;
106
107 sz = recvmsg(socket, &msg, 0);
108 check(sz==avail);
109
110 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
111 rec->contents.length = sz;
112
113 cmsg = CMSG_FIRSTHDR(&msg);
114 check(cmsg);
115 if(!cmsg)
116 return 0;
117
118 check(cmsg->cmsg_type == SCM_TLS_HEADER);
119 check(cmsg->cmsg_level == SOL_SOCKET);
120 check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
121 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
122 check(hdr);
123
124 /* print msg info */
125 /*
126 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
127 iov.iov_len,
128 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
129 hdr->content_type, hdr->protocol_version,
130 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
131 */
132 rec->contentType = hdr->content_type;
133 rec->protocolVersion = hdr->protocol_version;
134
135 if(rec->contentType==SSL_RecordTypeChangeCipher) {
136 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
137 }
138 return 0;
139 }
140
141 static
142 int TLSSocket_Free(SSLRecordContextRef ref,
143 SSLRecord rec)
144 {
145 free(rec.contents.data);
146 return 0;
147 }
148
149 static
150 int TLSSocket_Write(SSLRecordContextRef ref,
151 SSLRecord rec)
152 {
153 int socket = (int)ref;
154 ssize_t sz;
155
156 struct msghdr msg;
157 struct iovec iov;
158 tls_record_hdr_t hdr;
159 struct cmsghdr *cmsg;
160 int cbuf_len=CMSG_SPACE(sizeof(*hdr));
161 uint8_t cbuf[cbuf_len];
162
163 if(rec.contentType==SSL_RecordTypeChangeCipher) {
164 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
165 }
166 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
167
168 /* write the message */
169 iov.iov_base = rec.contents.data;
170 iov.iov_len = rec.contents.length;
171 msg.msg_name = NULL;
172 msg.msg_namelen = 0;
173 msg.msg_iov = &iov;
174 msg.msg_iovlen = 1;
175 msg.msg_control = cbuf;
176 msg.msg_controllen = cbuf_len;
177
178 cmsg = CMSG_FIRSTHDR(&msg);
179 cmsg->cmsg_level = SOL_SOCKET;
180 cmsg->cmsg_type = SCM_TLS_HEADER;
181 cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
182 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
183 hdr->content_type = rec.contentType;
184 hdr->protocol_version = rec.protocolVersion;
185
186 /* print msg info */
187 sz = sendmsg(socket, &msg, 0);
188
189 if(sz<0)
190 perror("sendmsg");
191
192 /*
193 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
194 iov.iov_len,
195 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
196 */
197
198 check(sz==rec.contents.length);
199
200 if(sz<0)
201 return (int)sz;
202 else
203 return 0;
204 }
205
206
207 static
208 int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref,
209 uint16_t selectedCipher,
210 bool server,
211 SSLBuffer key)
212 {
213 int socket = (int)ref;
214 int rc;
215 char *buf;
216
217 buf = malloc(key.length+3);
218 buf[0] = selectedCipher >> 8;
219 buf[1] = selectedCipher & 0xff;
220 buf[2] = server;
221 memcpy(buf+3, key.data, key.length);
222
223 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);
224
225 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));
226
227 printf("%s: rc=%d\n", __FUNCTION__, rc);
228
229 free(buf);
230
231 return rc;
232 }
233
234 static
235 int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
236 {
237 int socket = (int)ref;
238 int rc;
239 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
240
241 printf("%s: rc=%d\n", __FUNCTION__, rc);
242
243 return rc;
244 }
245
246 static
247 int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
248 {
249 int socket = (int)ref;
250 int rc;
251 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
252
253 printf("%s: rc=%d\n", __FUNCTION__, rc);
254
255 return rc;
256 }
257
258 static
259 int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref)
260 {
261 int socket = (int)ref;
262 int rc;
263 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
264
265 printf("%s: rc=%d\n", __FUNCTION__, rc);
266
267 return rc;
268 }
269
270 static
271 int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref,
272 SSLProtocolVersion protocolVersion)
273 {
274 int socket = (int)ref;
275 int rc;
276 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion));
277
278 printf("%s: rc=%d\n", __FUNCTION__, rc);
279
280 return rc;
281 }
282
283
284 static
285 int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref)
286 {
287 int socket = (int)ref;
288 int rc;
289 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
290
291 return rc;
292 }
293
294
295 static
296 int TLSSocket_SetOption(SSLRecordContextRef ref,
297 SSLRecordOption option,
298 bool value)
299 {
300 /* This is not implemented, and is not needed for DTLS */
301 return EINVAL;
302 }
303
304 const struct SSLRecordFuncs TLSSocket_Funcs = {
305 .read = TLSSocket_Read,
306 .write = TLSSocket_Write,
307 .initPendingCiphers = TLSSocket_InitPendingCiphers,
308 .advanceWriteCipher = TLSSocket_AdvanceWriteCipher,
309 .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
310 .advanceReadCipher = TLSSocket_AdvanceReadCipher,
311 .setProtocolVersion = TLSSocket_SetProtocolVersion,
312 .free = TLSSocket_Free,
313 .serviceWriteQueue = TLSSocket_ServiceWriteQueue,
314 .setOption = TLSSocket_SetOption,
315 };
316
317
318 /* TLSSocket SPIs */
319
320 int TLSSocket_Attach(int socket)
321 {
322
323 /* Attach the TLS socket filter and return handle */
324 struct so_nke so_tlsnke;
325 int rc;
326 int handle;
327 socklen_t len;
328
329 memset(&so_tlsnke, 0, sizeof(so_tlsnke));
330 so_tlsnke.nke_handle = TLS_HANDLE_IP4;
331 rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
332 if(rc)
333 return rc;
334
335 len = sizeof(handle);
336 rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
337 if(rc)
338 return rc;
339
340 assert(len==sizeof(handle));
341
342 return handle;
343 }
344