]> git.saurik.com Git - apple/security.git/blame - tlsnke/tlsnketest/tlssocket.c
Security-55471.14.tar.gz
[apple/security.git] / tlsnke / tlsnketest / tlssocket.c
CommitLineData
427c49bc
A
1//
2// tlssocket.c
3// tlsnke
4//
5// Created by Fabrice Gautier on 1/6/12.
6// Copyright (c) 2012 Apple, Inc. All rights reserved.
7//
8
9#include <Security/SecureTransportPriv.h>
10#include <string.h>
11#include <netinet/in.h>
12#include <arpa/inet.h>
13
14#include <stdlib.h>
15#include <stdio.h>
16#include <assert.h>
17
18#include <net/kext_net.h>
19
20#include "tlssocket.h"
21#include "tlsnke.h"
22
23#include <AssertMacros.h>
24#include <errno.h>
25
26/* TLSSocket functions */
27
28static
29int TLSSocket_Read(SSLRecordContextRef ref,
30 SSLRecord *rec)
31{
32 int socket = (int)ref;
33 int rc;
34 ssize_t sz;
35 struct sockaddr_in client_addr;
36 int avail;
37 socklen_t avail_size;
38 struct cmsghdr *cmsg;
39 tls_record_hdr_t hdr;
40 struct msghdr msg;
41 struct iovec iov;
42 int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024;
43 uint8_t cbuf[cbuf_len];
44
45
46 // printf("%s: Waiting for some data...\n", __FUNCTION__);
47 /* PEEK only... */
48 char b;
49 rc = (int)recv(socket, &b, 1, MSG_PEEK);
50
51 if(rc==-1)
52 {
53 if(errno==EAGAIN)
54 return errSSLRecordWouldBlock;
55 else {
56 perror("recv");
57 return errno;
58 }
59 }
60
61 /* get the next packet size */
62 avail_size = sizeof(avail);
63 rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size);
64
65 check_noerr(rc);
66 check(avail_size==sizeof(avail));
67
68 if(rc || (avail_size !=sizeof(avail)))
69 return errSSLRecordInternal;
70
71 // printf("%s: Available = %d\n", __FUNCTION__, avail);
72
73 if(avail==0)
74 return errSSLRecordWouldBlock;
75
76
77 /* Allocate a buffer */
78 rec->contents.data = malloc(avail);
79 rec->contents.length = avail;
80
81 /* read the message */
82 iov.iov_base = rec->contents.data;
83 iov.iov_len = rec->contents.length;
84 msg.msg_name = &client_addr;
85 msg.msg_namelen = sizeof(client_addr);
86 msg.msg_iov = &iov;
87 msg.msg_iovlen = 1;
88 msg.msg_control = cbuf;
89 msg.msg_controllen = cbuf_len;
90
91 sz = recvmsg(socket, &msg, 0);
92 check(sz==avail);
93
94 // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags);
95 rec->contents.length = sz;
96
97 cmsg = CMSG_FIRSTHDR(&msg);
98 check(cmsg);
99 if(!cmsg)
100 return 0;
101
102 check(cmsg->cmsg_type == SCM_TLS_HEADER);
103 check(cmsg->cmsg_level == SOL_SOCKET);
104 check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr)));
105 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
106 check(hdr);
107
108 /* print msg info */
109 /*
110 printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc,
111 iov.iov_len,
112 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type,
113 hdr->content_type, hdr->protocol_version,
114 inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port));
115 */
116 rec->contentType = hdr->content_type;
117 rec->protocolVersion = hdr->protocol_version;
118
119 if(rec->contentType==SSL_RecordTypeChangeCipher) {
120 printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__);
121 }
122 return 0;
123}
124
125static
126int TLSSocket_Free(SSLRecordContextRef ref,
127 SSLRecord rec)
128{
129 free(rec.contents.data);
130 return 0;
131}
132
133static
134int TLSSocket_Write(SSLRecordContextRef ref,
135 SSLRecord rec)
136{
137 int socket = (int)ref;
138 ssize_t sz;
139
140 struct msghdr msg;
141 struct iovec iov;
142 tls_record_hdr_t hdr;
143 struct cmsghdr *cmsg;
144 int cbuf_len=CMSG_SPACE(sizeof(*hdr));
145 uint8_t cbuf[cbuf_len];
146
147 if(rec.contentType==SSL_RecordTypeChangeCipher) {
148 printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__);
149 }
150 // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length);
151
152 /* write the message */
153 iov.iov_base = rec.contents.data;
154 iov.iov_len = rec.contents.length;
155 msg.msg_name = NULL;
156 msg.msg_namelen = 0;
157 msg.msg_iov = &iov;
158 msg.msg_iovlen = 1;
159 msg.msg_control = cbuf;
160 msg.msg_controllen = cbuf_len;
161
162 cmsg = CMSG_FIRSTHDR(&msg);
163 cmsg->cmsg_level = SOL_SOCKET;
164 cmsg->cmsg_type = SCM_TLS_HEADER;
165 cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr));
166 hdr = (tls_record_hdr_t)CMSG_DATA(cmsg);
167 hdr->content_type = rec.contentType;
168 hdr->protocol_version = rec.protocolVersion;
169
170 /* print msg info */
171 sz = sendmsg(socket, &msg, 0);
172
173 if(sz<0)
174 perror("sendmsg");
175
176 /*
177 printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz,
178 iov.iov_len,
179 cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type);
180 */
181
182 check(sz==rec.contents.length);
183
184 if(sz<0)
185 return (int)sz;
186 else
187 return 0;
188}
189
190
191static
192int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref,
193 uint16_t selectedCipher,
194 bool server,
195 SSLBuffer key)
196{
197 int socket = (int)ref;
198 int rc;
199 char *buf;
200
201 buf = malloc(key.length+3);
202 buf[0] = selectedCipher >> 8;
203 buf[1] = selectedCipher & 0xff;
204 buf[2] = server;
205 memcpy(buf+3, key.data, key.length);
206
207 printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length);
208
209 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3));
210
211 printf("%s: rc=%d\n", __FUNCTION__, rc);
212
213 free(buf);
214
215 return rc;
216}
217
218static
219int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref)
220{
221 int socket = (int)ref;
222 int rc;
223 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0);
224
225 printf("%s: rc=%d\n", __FUNCTION__, rc);
226
227 return rc;
228}
229
230static
231int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref)
232{
233 int socket = (int)ref;
234 int rc;
235 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0);
236
237 printf("%s: rc=%d\n", __FUNCTION__, rc);
238
239 return rc;
240}
241
242static
243int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref)
244{
245 int socket = (int)ref;
246 int rc;
247 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0);
248
249 printf("%s: rc=%d\n", __FUNCTION__, rc);
250
251 return rc;
252}
253
254static
255int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref,
256 SSLProtocolVersion protocolVersion)
257{
258 int socket = (int)ref;
259 int rc;
260 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion));
261
262 printf("%s: rc=%d\n", __FUNCTION__, rc);
263
264 return rc;
265}
266
267
268static
269int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref)
270{
271 int socket = (int)ref;
272 int rc;
273 rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0);
274
275 return rc;
276}
277
278
279const struct SSLRecordFuncs TLSSocket_Funcs = {
280 .read = TLSSocket_Read,
281 .write = TLSSocket_Write,
282 .initPendingCiphers = TLSSocket_InitPendingCiphers,
283 .advanceWriteCipher = TLSSocket_AdvanceWriteCipher,
284 .rollbackWriteCipher = TLSSocket_RollbackWriteCipher,
285 .advanceReadCipher = TLSSocket_AdvanceReadCipher,
286 .setProtocolVersion = TLSSocket_SetProtocolVersion,
287 .free = TLSSocket_Free,
288 .serviceWriteQueue = TLSSocket_ServiceWriteQueue,
289};
290
291
292/* TLSSocket SPIs */
293
294int TLSSocket_Attach(int socket)
295{
296
297 /* Attach the TLS socket filter and return handle */
298 struct so_nke so_tlsnke;
299 int rc;
300 int handle;
301 socklen_t len;
302
303 memset(&so_tlsnke, 0, sizeof(so_tlsnke));
304 so_tlsnke.nke_handle = TLS_HANDLE_IP4;
305 rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke));
306 if(rc)
307 return rc;
308
309 len = sizeof(handle);
310 rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len);
311 if(rc)
312 return rc;
313
314 assert(len==sizeof(handle));
315
316 return handle;
317}
318