]> git.saurik.com Git - apple/security.git/blob - tlsnke/tlsnketest/tlssocket.c
Security-55471.14.tar.gz
[apple/security.git] / tlsnke / tlsnketest / tlssocket.c
1 //
2 // tlssocket.c
3 // tlsnke
4 //
5 // Created by Fabrice Gautier on 1/6/12.
6 // Copyright (c) 2012 Apple, Inc. All rights reserved.
7 //
8
#include <Security/SecureTransportPriv.h>
#include <string.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <assert.h>

#include <net/kext_net.h>

#include "tlssocket.h"
#include "tlsnke.h"

#include <AssertMacros.h>
#include <errno.h>
25
26 /* TLSSocket functions */
27
28 static
29 int TLSSocket_Read(SSLRecordContextRef ref,
30 SSLRecord *rec)
31 {
32 int socket = (int)ref;
33 int rc;
34 ssize_t sz;
35 struct sockaddr_in client_addr;
36 int avail;
37 socklen_t avail_size;
38 struct cmsghdr *cmsg;
39 tls_record_hdr_t hdr;
40 struct msghdr msg;
41 struct iovec iov;
42 int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
43 uint8_t cbuf[cbuf_len];
44
45
46 // printf("%s: Waiting for some data...\n", __FUNCTION__);
47 /* PEEK only... */
48 char b;
49 rc = (int)recv(socket, &b, 1, MSG_PEEK);
50
51 if(rc==-1)
52 {
53 if(errno==EAGAIN)
54 return errSSLRecordWouldBlock;
55 else {
56 perror("recv");
57 return errno;
58 }
59 }
60
61 /* get the next packet size */
62 avail_size = sizeof(avail);
63 rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
64
65 check_noerr(rc);
66 check(avail_size==sizeof(avail));
67
68 if(rc || (avail_size !=sizeof(avail)))
69 return errSSLRecordInternal;
70
71 // printf("%s: Available = %d\n", __FUNCTION__, avail);
72
73 if(avail==0)
74 return errSSLRecordWouldBlock;
75
76
77 /* Allocate a buffer */
78 rec->contents.data = malloc(avail);
79 rec->contents.length = avail;
80
81 /* read the message */
82 iov.iov_base = rec->contents.data;
83 iov.iov_len = rec->contents.length;
84 msg.msg_name = &client_addr;
85 msg.msg_namelen = sizeof(client_addr);
86 msg.msg_iov = &iov;
87 msg.msg_iovlen = 1;
88 msg.msg_control = cbuf;
89 msg.msg_controllen = cbuf_len;
90
91 sz = recvmsg(socket, &msg, 0);
92 check(sz==avail);
93
94 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
95 rec->contents.length = sz;
96
97 cmsg = CMSG_FIRSTHDR(&msg);
98 check(cmsg);
99 if(!cmsg)
100 return 0;
101
102 check(cmsg->cmsg_type == SCM_TLS_HEADER);
103 check(cmsg->cmsg_level == SOL_SOCKET);
104 check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
105 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
106 check(hdr);
107
108 /* print msg info */
109 /*
110 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
111 iov.iov_len,
112 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
113 hdr->content_type, hdr->protocol_version,
114 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
115 */
116 rec->contentType = hdr->content_type;
117 rec->protocolVersion = hdr->protocol_version;
118
119 if(rec->contentType==SSL_RecordTypeChangeCipher) {
120 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
121 }
122 return 0;
123 }
124
125 static
126 int TLSSocket_Free(SSLRecordContextRef ref,
127 SSLRecord rec)
128 {
129 free(rec.contents.data);
130 return 0;
131 }
132
133 static
134 int TLSSocket_Write(SSLRecordContextRef ref,
135 SSLRecord rec)
136 {
137 int socket = (int)ref;
138 ssize_t sz;
139
140 struct msghdr msg;
141 struct iovec iov;
142 tls_record_hdr_t hdr;
143 struct cmsghdr *cmsg;
144 int cbuf_len=CMSG_SPACE(sizeof(*hdr));
145 uint8_t cbuf[cbuf_len];
146
147 if(rec.contentType==SSL_RecordTypeChangeCipher) {
148 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
149 }
150 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
151
152 /* write the message */
153 iov.iov_base = rec.contents.data;
154 iov.iov_len = rec.contents.length;
155 msg.msg_name = NULL;
156 msg.msg_namelen = 0;
157 msg.msg_iov = &iov;
158 msg.msg_iovlen = 1;
159 msg.msg_control = cbuf;
160 msg.msg_controllen = cbuf_len;
161
162 cmsg = CMSG_FIRSTHDR(&msg);
163 cmsg->cmsg_level = SOL_SOCKET;
164 cmsg->cmsg_type = SCM_TLS_HEADER;
165 cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
166 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
167 hdr->content_type = rec.contentType;
168 hdr->protocol_version = rec.protocolVersion;
169
170 /* print msg info */
171 sz = sendmsg(socket, &msg, 0);
172
173 if(sz<0)
174 perror("sendmsg");
175
176 /*
177 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
178 iov.iov_len,
179 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
180 */
181
182 check(sz==rec.contents.length);
183
184 if(sz<0)
185 return (int)sz;
186 else
187 return 0;
188 }
189
190
191 static
192 int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref,
193 uint16_t selectedCipher,
194 bool server,
195 SSLBuffer key)
196 {
197 int socket = (int)ref;
198 int rc;
199 char *buf;
200
201 buf = malloc(key.length+3);
202 buf[0] = selectedCipher >> 8;
203 buf[1] = selectedCipher & 0xff;
204 buf[2] = server;
205 memcpy(buf+3, key.data, key.length);
206
207 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);
208
209 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));
210
211 printf("%s: rc=%d\n", __FUNCTION__, rc);
212
213 free(buf);
214
215 return rc;
216 }
217
218 static
219 int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
220 {
221 int socket = (int)ref;
222 int rc;
223 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
224
225 printf("%s: rc=%d\n", __FUNCTION__, rc);
226
227 return rc;
228 }
229
230 static
231 int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
232 {
233 int socket = (int)ref;
234 int rc;
235 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
236
237 printf("%s: rc=%d\n", __FUNCTION__, rc);
238
239 return rc;
240 }
241
242 static
243 int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref)
244 {
245 int socket = (int)ref;
246 int rc;
247 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
248
249 printf("%s: rc=%d\n", __FUNCTION__, rc);
250
251 return rc;
252 }
253
254 static
255 int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref,
256 SSLProtocolVersion protocolVersion)
257 {
258 int socket = (int)ref;
259 int rc;
260 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion));
261
262 printf("%s: rc=%d\n", __FUNCTION__, rc);
263
264 return rc;
265 }
266
267
268 static
269 int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref)
270 {
271 int socket = (int)ref;
272 int rc;
273 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
274
275 return rc;
276 }
277
278
279 const struct SSLRecordFuncs TLSSocket_Funcs = {
280 .read = TLSSocket_Read,
281 .write = TLSSocket_Write,
282 .initPendingCiphers = TLSSocket_InitPendingCiphers,
283 .advanceWriteCipher = TLSSocket_AdvanceWriteCipher,
284 .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
285 .advanceReadCipher = TLSSocket_AdvanceReadCipher,
286 .setProtocolVersion = TLSSocket_SetProtocolVersion,
287 .free = TLSSocket_Free,
288 .serviceWriteQueue = TLSSocket_ServiceWriteQueue,
289 };
290
291
292 /* TLSSocket SPIs */
293
294 int TLSSocket_Attach(int socket)
295 {
296
297 /* Attach the TLS socket filter and return handle */
298 struct so_nke so_tlsnke;
299 int rc;
300 int handle;
301 socklen_t len;
302
303 memset(&so_tlsnke, 0, sizeof(so_tlsnke));
304 so_tlsnke.nke_handle = TLS_HANDLE_IP4;
305 rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
306 if(rc)
307 return rc;
308
309 len = sizeof(handle);
310 rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
311 if(rc)
312 return rc;
313
314 assert(len==sizeof(handle));
315
316 return handle;
317 }
318