]>
Commit | Line | Data |
---|---|---|
427c49bc A |
1 | // |
2 | // tlssocket.c | |
3 | // tlsnke | |
4 | // | |
5 | // Created by Fabrice Gautier on 1/6/12. | |
6 | // Copyright (c) 2012 Apple, Inc. All rights reserved. | |
7 | // | |
8 | ||
9 | #include <Security/SecureTransportPriv.h> | |
10 | #include <string.h> | |
11 | #include <netinet/in.h> | |
12 | #include <arpa/inet.h> | |
13 | ||
14 | #include <stdlib.h> | |
15 | #include <stdio.h> | |
16 | #include <assert.h> | |
17 | ||
18 | #include <net/kext_net.h> | |
19 | ||
20 | #include "tlssocket.h" | |
21 | #include "tlsnke.h" | |
22 | ||
23 | #include <AssertMacros.h> | |
24 | #include <errno.h> | |
25 | ||
26 | /* TLSSocket functions */ | |
27 | ||
28 | static | |
29 | int TLSSocket_Read(SSLRecordContextRef ref, | |
30 | SSLRecord *rec) | |
31 | { | |
32 | int socket = (int)ref; | |
33 | int rc; | |
34 | ssize_t sz; | |
35 | struct sockaddr_in client_addr; | |
36 | int avail; | |
37 | socklen_t avail_size; | |
38 | struct cmsghdr *cmsg; | |
39 | tls_record_hdr_t hdr; | |
40 | struct msghdr msg; | |
41 | struct iovec iov; | |
42 | int cbuf_len=CMSG_SPACE(sizeof(*hdr))+1024; | |
43 | uint8_t cbuf[cbuf_len]; | |
44 | ||
45 | ||
46 | // printf("%s: Waiting for some data...\n", __FUNCTION__); | |
47 | /* PEEK only... */ | |
48 | char b; | |
49 | rc = (int)recv(socket, &b, 1, MSG_PEEK); | |
50 | ||
51 | if(rc==-1) | |
52 | { | |
53 | if(errno==EAGAIN) | |
54 | return errSSLRecordWouldBlock; | |
55 | else { | |
56 | perror("recv"); | |
57 | return errno; | |
58 | } | |
59 | } | |
60 | ||
61 | /* get the next packet size */ | |
62 | avail_size = sizeof(avail); | |
63 | rc = getsockopt(socket, SOL_SOCKET, SO_NREAD, &avail, &avail_size); | |
64 | ||
65 | check_noerr(rc); | |
66 | check(avail_size==sizeof(avail)); | |
67 | ||
68 | if(rc || (avail_size !=sizeof(avail))) | |
69 | return errSSLRecordInternal; | |
70 | ||
71 | // printf("%s: Available = %d\n", __FUNCTION__, avail); | |
72 | ||
73 | if(avail==0) | |
74 | return errSSLRecordWouldBlock; | |
75 | ||
76 | ||
77 | /* Allocate a buffer */ | |
78 | rec->contents.data = malloc(avail); | |
79 | rec->contents.length = avail; | |
80 | ||
81 | /* read the message */ | |
82 | iov.iov_base = rec->contents.data; | |
83 | iov.iov_len = rec->contents.length; | |
84 | msg.msg_name = &client_addr; | |
85 | msg.msg_namelen = sizeof(client_addr); | |
86 | msg.msg_iov = &iov; | |
87 | msg.msg_iovlen = 1; | |
88 | msg.msg_control = cbuf; | |
89 | msg.msg_controllen = cbuf_len; | |
90 | ||
91 | sz = recvmsg(socket, &msg, 0); | |
92 | check(sz==avail); | |
93 | ||
94 | // printf("%s: received = %ld, ctrl: l=%d f=%x\n", __FUNCTION__, sz, msg.msg_controllen, msg.msg_flags); | |
95 | rec->contents.length = sz; | |
96 | ||
97 | cmsg = CMSG_FIRSTHDR(&msg); | |
98 | check(cmsg); | |
99 | if(!cmsg) | |
100 | return 0; | |
101 | ||
102 | check(cmsg->cmsg_type == SCM_TLS_HEADER); | |
103 | check(cmsg->cmsg_level == SOL_SOCKET); | |
104 | check(cmsg->cmsg_len == CMSG_LEN(sizeof(*hdr))); | |
105 | hdr = (tls_record_hdr_t)CMSG_DATA(cmsg); | |
106 | check(hdr); | |
107 | ||
108 | /* print msg info */ | |
109 | /* | |
110 | printf("%s: rc=%d, msg: %ld , cmsg = %d, %x, %x, hdr = %d, %x - from %s:%d\n", __FUNCTION__, rc, | |
111 | iov.iov_len, | |
112 | cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type, | |
113 | hdr->content_type, hdr->protocol_version, | |
114 | inet_ntoa(client_addr.sin_addr),ntohs(client_addr.sin_port)); | |
115 | */ | |
116 | rec->contentType = hdr->content_type; | |
117 | rec->protocolVersion = hdr->protocol_version; | |
118 | ||
119 | if(rec->contentType==SSL_RecordTypeChangeCipher) { | |
120 | printf("%s: Received ChangeCipherSpec message\n", __FUNCTION__); | |
121 | } | |
122 | return 0; | |
123 | } | |
124 | ||
125 | static | |
126 | int TLSSocket_Free(SSLRecordContextRef ref, | |
127 | SSLRecord rec) | |
128 | { | |
129 | free(rec.contents.data); | |
130 | return 0; | |
131 | } | |
132 | ||
133 | static | |
134 | int TLSSocket_Write(SSLRecordContextRef ref, | |
135 | SSLRecord rec) | |
136 | { | |
137 | int socket = (int)ref; | |
138 | ssize_t sz; | |
139 | ||
140 | struct msghdr msg; | |
141 | struct iovec iov; | |
142 | tls_record_hdr_t hdr; | |
143 | struct cmsghdr *cmsg; | |
144 | int cbuf_len=CMSG_SPACE(sizeof(*hdr)); | |
145 | uint8_t cbuf[cbuf_len]; | |
146 | ||
147 | if(rec.contentType==SSL_RecordTypeChangeCipher) { | |
148 | printf("%s: Sending ChangeCipherSpec message\n", __FUNCTION__); | |
149 | } | |
150 | // printf("%s: fd=%d, rec.len=%ld\n", __FUNCTION__, socket, rec.contents.length); | |
151 | ||
152 | /* write the message */ | |
153 | iov.iov_base = rec.contents.data; | |
154 | iov.iov_len = rec.contents.length; | |
155 | msg.msg_name = NULL; | |
156 | msg.msg_namelen = 0; | |
157 | msg.msg_iov = &iov; | |
158 | msg.msg_iovlen = 1; | |
159 | msg.msg_control = cbuf; | |
160 | msg.msg_controllen = cbuf_len; | |
161 | ||
162 | cmsg = CMSG_FIRSTHDR(&msg); | |
163 | cmsg->cmsg_level = SOL_SOCKET; | |
164 | cmsg->cmsg_type = SCM_TLS_HEADER; | |
165 | cmsg->cmsg_len = CMSG_LEN(sizeof(*hdr)); | |
166 | hdr = (tls_record_hdr_t)CMSG_DATA(cmsg); | |
167 | hdr->content_type = rec.contentType; | |
168 | hdr->protocol_version = rec.protocolVersion; | |
169 | ||
170 | /* print msg info */ | |
171 | sz = sendmsg(socket, &msg, 0); | |
172 | ||
173 | if(sz<0) | |
174 | perror("sendmsg"); | |
175 | ||
176 | /* | |
177 | printf("%s: sz=%ld, msg: %ld , cmsg = %d, %d, %04x\n", __FUNCTION__, sz, | |
178 | iov.iov_len, | |
179 | cmsg->cmsg_len, cmsg->cmsg_level, cmsg->cmsg_type); | |
180 | */ | |
181 | ||
182 | check(sz==rec.contents.length); | |
183 | ||
184 | if(sz<0) | |
185 | return (int)sz; | |
186 | else | |
187 | return 0; | |
188 | } | |
189 | ||
190 | ||
191 | static | |
192 | int TLSSocket_InitPendingCiphers(SSLRecordContextRef ref, | |
193 | uint16_t selectedCipher, | |
194 | bool server, | |
195 | SSLBuffer key) | |
196 | { | |
197 | int socket = (int)ref; | |
198 | int rc; | |
199 | char *buf; | |
200 | ||
201 | buf = malloc(key.length+3); | |
202 | buf[0] = selectedCipher >> 8; | |
203 | buf[1] = selectedCipher & 0xff; | |
204 | buf[2] = server; | |
205 | memcpy(buf+3, key.data, key.length); | |
206 | ||
207 | printf("%s: cipher=%04x, keylen=%ld\n", __FUNCTION__, selectedCipher, key.length); | |
208 | ||
209 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_INIT_CIPHER, buf, (socklen_t)(key.length+3)); | |
210 | ||
211 | printf("%s: rc=%d\n", __FUNCTION__, rc); | |
212 | ||
213 | free(buf); | |
214 | ||
215 | return rc; | |
216 | } | |
217 | ||
218 | static | |
219 | int TLSSocket_AdvanceWriteCipher(SSLRecordContextRef ref) | |
220 | { | |
221 | int socket = (int)ref; | |
222 | int rc; | |
223 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_WRITE_CIPHER, NULL, 0); | |
224 | ||
225 | printf("%s: rc=%d\n", __FUNCTION__, rc); | |
226 | ||
227 | return rc; | |
228 | } | |
229 | ||
230 | static | |
231 | int TLSSocket_RollbackWriteCipher(SSLRecordContextRef ref) | |
232 | { | |
233 | int socket = (int)ref; | |
234 | int rc; | |
235 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ROLLBACK_WRITE_CIPHER, NULL, 0); | |
236 | ||
237 | printf("%s: rc=%d\n", __FUNCTION__, rc); | |
238 | ||
239 | return rc; | |
240 | } | |
241 | ||
242 | static | |
243 | int TLSSocket_AdvanceReadCipher(SSLRecordContextRef ref) | |
244 | { | |
245 | int socket = (int)ref; | |
246 | int rc; | |
247 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_ADVANCE_READ_CIPHER, NULL, 0); | |
248 | ||
249 | printf("%s: rc=%d\n", __FUNCTION__, rc); | |
250 | ||
251 | return rc; | |
252 | } | |
253 | ||
254 | static | |
255 | int TLSSocket_SetProtocolVersion(SSLRecordContextRef ref, | |
256 | SSLProtocolVersion protocolVersion) | |
257 | { | |
258 | int socket = (int)ref; | |
259 | int rc; | |
260 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_PROTOCOL_VERSION, &protocolVersion, sizeof(protocolVersion)); | |
261 | ||
262 | printf("%s: rc=%d\n", __FUNCTION__, rc); | |
263 | ||
264 | return rc; | |
265 | } | |
266 | ||
267 | ||
268 | static | |
269 | int TLSSocket_ServiceWriteQueue(SSLRecordContextRef ref) | |
270 | { | |
271 | int socket = (int)ref; | |
272 | int rc; | |
273 | rc = setsockopt(socket, SOL_SOCKET, SO_TLS_SERVICE_WRITE_QUEUE, NULL, 0); | |
274 | ||
275 | return rc; | |
276 | } | |
277 | ||
278 | ||
279 | const struct SSLRecordFuncs TLSSocket_Funcs = { | |
280 | .read = TLSSocket_Read, | |
281 | .write = TLSSocket_Write, | |
282 | .initPendingCiphers = TLSSocket_InitPendingCiphers, | |
283 | .advanceWriteCipher = TLSSocket_AdvanceWriteCipher, | |
284 | .rollbackWriteCipher = TLSSocket_RollbackWriteCipher, | |
285 | .advanceReadCipher = TLSSocket_AdvanceReadCipher, | |
286 | .setProtocolVersion = TLSSocket_SetProtocolVersion, | |
287 | .free = TLSSocket_Free, | |
288 | .serviceWriteQueue = TLSSocket_ServiceWriteQueue, | |
289 | }; | |
290 | ||
291 | ||
292 | /* TLSSocket SPIs */ | |
293 | ||
294 | int TLSSocket_Attach(int socket) | |
295 | { | |
296 | ||
297 | /* Attach the TLS socket filter and return handle */ | |
298 | struct so_nke so_tlsnke; | |
299 | int rc; | |
300 | int handle; | |
301 | socklen_t len; | |
302 | ||
303 | memset(&so_tlsnke, 0, sizeof(so_tlsnke)); | |
304 | so_tlsnke.nke_handle = TLS_HANDLE_IP4; | |
305 | rc=setsockopt(socket, SOL_SOCKET, SO_NKE, &so_tlsnke, sizeof(so_tlsnke)); | |
306 | if(rc) | |
307 | return rc; | |
308 | ||
309 | len = sizeof(handle); | |
310 | rc = getsockopt(socket, SOL_SOCKET, SO_TLS_HANDLE, &handle, &len); | |
311 | if(rc) | |
312 | return rc; | |
313 | ||
314 | assert(len==sizeof(handle)); | |
315 | ||
316 | return handle; | |
317 | } | |
318 |