+#if( TARGET_OS_DARWIN )
+//===========================================================================================================================
+// KeepAliveTestCmd
+//===========================================================================================================================
+
+// Selects which DNSServiceSleepKeepalive* API a subtest uses to register the keepalive record.
+typedef enum
+{
+ kKeepAliveCallVariant_Null = 0, // Not a valid variant; _KeepAliveTestHandleConnection() treats it as an internal error.
+ kKeepAliveCallVariant_TakesSocket = 1, // DNSServiceSleepKeepalive(), which takes a connected socket.
+ kKeepAliveCallVariant_TakesSockAddrs = 2, // DNSServiceSleepKeepalive_sockaddr(), which takes connection's sockaddrs.
+
+} KeepAliveCallVariant;
+
+// Static parameters describing one keepalive subtest.
+typedef struct
+{
+ int family; // TCP connection's address family (AF_INET or AF_INET6).
+ KeepAliveCallVariant callVariant; // Describes which DNSServiceSleepKeepalive* call to use.
+ const char * description; // Human-readable description; copied into the subtest's results.
+
+} KeepAliveSubtestParams;
+
+// Parameters for each subtest, indexed by KeepAliveTest::subtestIdx. The number of entries must match the size of
+// KeepAliveTest::subtests, which is enforced by a check_compile_time() after that struct's definition.
+// The table is private to this file, so it has internal linkage.
+static const KeepAliveSubtestParams kKeepAliveSubtestParams[] =
+{
+    { AF_INET,  kKeepAliveCallVariant_TakesSocket,    "Calls DNSServiceSleepKeepalive() for IPv4 TCP connection." },
+    { AF_INET,  kKeepAliveCallVariant_TakesSockAddrs, "Calls DNSServiceSleepKeepalive_sockaddr() for IPv4 TCP connection." },
+    { AF_INET6, kKeepAliveCallVariant_TakesSocket,    "Calls DNSServiceSleepKeepalive() for IPv6 TCP connection." },
+    { AF_INET6, kKeepAliveCallVariant_TakesSockAddrs, "Calls DNSServiceSleepKeepalive_sockaddr() for IPv6 TCP connection." }
+};
+
+// Per-subtest state and results. recordName and dataStr are heap strings freed by _KeepAliveTestFree();
+// clientSock and serverSock are closed by _KeepAliveTestStop().
+typedef struct
+{
+ sockaddr_ip local; // TCP connection's local address and port.
+ sockaddr_ip remote; // TCP connection's remote address and port.
+ NanoTime64 startTime; // Subtest's start time.
+ NanoTime64 endTime; // Subtest's end time.
+ SocketRef clientSock; // Socket for client-side of TCP connection.
+ SocketRef serverSock; // Socket for server-side of TCP connection.
+ char * recordName; // Keepalive record's name.
+ char * dataStr; // Data expected to be contained in keepalive record's data.
+ const char * description; // Subtest's description (points into kKeepAliveSubtestParams; not owned).
+ unsigned int timeoutKA; // Randomly-generated timeout value that gets put in keepalive record's rdata.
+ OSStatus error; // Subtest's error (kInProgressErr while running, kNoErr on success).
+
+} KeepAliveSubtest;
+
+// Opaque reference to a keepalive test instance.
+typedef struct KeepAliveTest * KeepAliveTestRef;
+
+// Context for the asynchronous TCP connection started by _KeepAliveTestStartSubtest(). It is heap-allocated, freed by
+// _KeepAliveTestConnectionHandler(), and its test pointer is cleared by _KeepAliveTestForgetConnection() so that a
+// handler invocation after the connection was forgotten does not touch the test.
+typedef struct
+{
+ KeepAliveTestRef test; // Weak back pointer to test.
+
+} KeepAliveTestConnectionContext;
+
+// Keepalive test state. All event handling happens on the serial queue; the thread that calls _KeepAliveTestRun()
+// blocks on doneSem until _KeepAliveTestStop() signals it.
+struct KeepAliveTest
+{
+ dispatch_queue_t queue; // Serial queue for test events.
+ dispatch_semaphore_t doneSem; // Semaphore to signal when the test is done.
+ dispatch_source_t readSource; // Read source for TCP listener socket.
+ DNSServiceRef keepalive; // DNSServiceSleepKeepalive{,_sockaddr} operation.
+ DNSServiceRef query; // Query to verify registered keepalive record.
+ dispatch_source_t timer; // Timer to put time limit on query.
+ AsyncConnectionRef connection; // Establishes current subtest's TCP connection.
+ KeepAliveTestConnectionContext * connectionCtx; // Weak pointer to connection's context.
+ NanoTime64 startTime; // Test's start time.
+ NanoTime64 endTime; // Test's end time.
+ OSStatus error; // Test's error.
+ size_t subtestIdx; // Index of current subtest.
+ KeepAliveSubtest subtests[ 4 ]; // Subtest array.
+};
+// One subtest per kKeepAliveSubtestParams entry.
+check_compile_time( countof_field( struct KeepAliveTest, subtests ) == countof( kKeepAliveSubtestParams ) );
+
+// Log category for this test (kLogLevelInfo is presumably the category's default level — confirm against ulog).
+ulog_define_ex( kDNSSDUtilIdentifier, KeepAliveTest, kLogLevelInfo, kLogFlags_None, "KeepAliveTest", NULL );
+// Logs a printf-style message at LEVEL to the KeepAliveTest category.
+#define kat_ulog( LEVEL, ... ) ulog( &log_category_from_name( KeepAliveTest ), (LEVEL), __VA_ARGS__ )
+
+// Lifecycle functions used by KeepAliveTestCmd().
+static OSStatus _KeepAliveTestCreate( KeepAliveTestRef *outTest );
+static OSStatus _KeepAliveTestRun( KeepAliveTestRef inTest );
+static void _KeepAliveTestFree( KeepAliveTestRef inTest );
+
+// Runs the keepalive test and reports its results.
+//
+// Each entry of kKeepAliveSubtestParams is run as a subtest. The results are collected into a property list of the form
+// { startTime, endTime, subtests: [ ... ], pass } and written with OutputPropertyList() using the output format and
+// file path given by gKeepAliveTest_OutputFormat and gKeepAliveTest_OutputFilePath.
+//
+// Sets gExitCode to 1 if the test couldn't be run or its results couldn't be output, to 2 if at least one subtest
+// failed, and to 0 if every subtest passed.
+static void KeepAliveTestCmd( void )
+{
+    OSStatus            err;
+    OutputFormatType    outputFormat;
+    KeepAliveTestRef    test = NULL;
+    CFPropertyListRef   plist = NULL;
+    CFMutableArrayRef   subtests;
+    size_t              i;
+    size_t              subtestFailCount;
+    Boolean             testPassed = false;
+    char                startTime[ 32 ];
+    char                endTime[ 32 ];
+
+    err = OutputFormatFromArgString( gKeepAliveTest_OutputFormat, &outputFormat );
+    require_noerr_quiet( err, exit );
+
+    err = _KeepAliveTestCreate( &test );
+    require_noerr( err, exit );
+
+    // Blocks until the test is done. A non-zero error here is a test-level failure, not merely a failed subtest.
+    err = _KeepAliveTestRun( test );
+    require_noerr( err, exit );
+
+    _NanoTime64ToTimestamp( test->startTime, startTime, sizeof( startTime ) );
+    _NanoTime64ToTimestamp( test->endTime, endTime, sizeof( endTime ) );
+    err = CFPropertyListCreateFormatted( kCFAllocatorDefault, &plist,
+        "{"
+            "%kO=%s"    // startTime
+            "%kO=%s"    // endTime
+            "%kO=[%@]"  // subtests
+        "}",
+        CFSTR( "startTime" ),   startTime,
+        CFSTR( "endTime" ),     endTime,
+        CFSTR( "subtests" ),    &subtests );
+    require_noerr( err, exit );
+
+    // Append one results dictionary per subtest and count the failed subtests.
+    subtestFailCount = 0;
+    check( test->subtestIdx == countof( test->subtests ) );
+    for( i = 0; i < countof( test->subtests ); ++i )
+    {
+        KeepAliveSubtest * const    subtest = &test->subtests[ i ];
+        char                        errorDesc[ 128 ];
+
+        _NanoTime64ToTimestamp( subtest->startTime, startTime, sizeof( startTime ) );
+        _NanoTime64ToTimestamp( subtest->endTime, endTime, sizeof( endTime ) );
+        SNPrintF( errorDesc, sizeof( errorDesc ), "%m", subtest->error );
+        err = CFPropertyListAppendFormatted( kCFAllocatorDefault, subtests,
+            "{"
+                "%kO=%s"        // startTime
+                "%kO=%s"        // endTime
+                "%kO=%s"        // description
+                "%kO=%##a"      // localAddr
+                "%kO=%##a"      // remoteAddr
+                "%kO=%s"        // recordName
+                "%kO=%s"        // expectedRData
+                "%kO="          // error
+                "{"
+                    "%kO=%lli"  // code
+                    "%kO=%s"    // description
+                "}"
+            "}",
+            CFSTR( "startTime" ),       startTime,
+            CFSTR( "endTime" ),         endTime,
+            CFSTR( "description" ),     subtest->description,
+            CFSTR( "localAddr" ),       &subtest->local.sa,
+            CFSTR( "remoteAddr" ),      &subtest->remote.sa,
+            CFSTR( "recordName" ),      subtest->recordName,
+            CFSTR( "expectedRData" ),   subtest->dataStr,
+            CFSTR( "error" ),
+            CFSTR( "code" ),            (int64_t) subtest->error,
+            CFSTR( "description" ),     errorDesc
+        );
+        require_noerr( err, exit );
+        if( subtest->error ) ++subtestFailCount;
+    }
+    testPassed = ( subtestFailCount == 0 ) ? true : false;
+
+    // Check the result so that a plist missing its "pass" key is never output as if it were complete.
+    err = CFPropertyListAppendFormatted( kCFAllocatorDefault, plist, "%kO=%b", CFSTR( "pass" ), testPassed );
+    require_noerr( err, exit );
+
+    err = OutputPropertyList( plist, outputFormat, gKeepAliveTest_OutputFilePath );
+    require_noerr( err, exit );
+
+exit:
+    if( test ) _KeepAliveTestFree( test );
+    CFReleaseNullSafe( plist );
+    gExitCode = err ? 1 : ( testPassed ? 0 : 2 );
+}
+
+//===========================================================================================================================
+
+// Internal test machinery. Except for _KeepAliveTestRun()'s caller, everything below runs on the test's serial queue.
+static void _KeepAliveTestStart( void *inCtx );
+static void _KeepAliveTestStop( KeepAliveTestRef inTest, OSStatus inError );
+static OSStatus _KeepAliveTestStartSubtest( KeepAliveTestRef inTest );
+static void _KeepAliveTestStopSubtest( KeepAliveTestRef inTest );
+static KeepAliveSubtest * _KeepAliveTestGetSubtest( KeepAliveTestRef inTest );
+static const char * _KeepAliveTestGetSubtestLogPrefix( KeepAliveTestRef inTest, char *inBufPtr, size_t inBufLen );
+static OSStatus _KeepAliveTestContinue( KeepAliveTestRef inTest, OSStatus inSubtestError, Boolean *outDone );
+static void _KeepAliveTestTCPAcceptHandler( void *inCtx );
+static void _KeepAliveTestConnectionHandler( SocketRef inSock, OSStatus inError, void *inArg );
+static void _KeepAliveTestHandleConnection( KeepAliveTestRef inTest, SocketRef inSock, OSStatus inError );
+static void _KeepAliveTestForgetConnection( KeepAliveTestRef inTest );
+static void DNSSD_API _KeepAliveTestKeepaliveCallback( DNSServiceRef inSDRef, DNSServiceErrorType inErr, void *inCtx );
+static void _KeepAliveTestQueryTimerHandler( void *inCtx );
+static void DNSSD_API
+ _KeepAliveTestQueryRecordCallback(
+ DNSServiceRef inSDRef,
+ DNSServiceFlags inFlags,
+ uint32_t inInterfaceIndex,
+ DNSServiceErrorType inError,
+ const char * inFullName,
+ uint16_t inType,
+ uint16_t inClass,
+ uint16_t inRDataLen,
+ const void * inRDataPtr,
+ uint32_t inTTL,
+ void * inCtx );
+
+// Allocates a test object with every subtest's addresses marked AF_UNSPEC and sockets marked invalid, and creates the
+// serial event queue and the completion semaphore.
+// On success, *outTest receives the new test, which the caller frees with _KeepAliveTestFree().
+static OSStatus _KeepAliveTestCreate( KeepAliveTestRef *outTest )
+{
+    OSStatus                    err;
+    KeepAliveTestRef            obj;
+    KeepAliveSubtest *          st;
+    KeepAliveSubtest *          stEnd;
+
+    obj = (KeepAliveTestRef) calloc( 1, sizeof( *obj ) );
+    require_action( obj, exit, err = kNoMemoryErr );
+
+    obj->error = kInProgressErr;
+    stEnd = obj->subtests + countof( obj->subtests );
+    for( st = obj->subtests; st < stEnd; ++st )
+    {
+        st->local.sa.sa_family  = AF_UNSPEC;
+        st->remote.sa.sa_family = AF_UNSPEC;
+        st->clientSock          = kInvalidSocketRef;
+        st->serverSock          = kInvalidSocketRef;
+    }
+
+    obj->queue = dispatch_queue_create( "com.apple.dnssdutil.keepalive-test", DISPATCH_QUEUE_SERIAL );
+    require_action( obj->queue, exit, err = kNoResourcesErr );
+
+    obj->doneSem = dispatch_semaphore_create( 0 );
+    require_action( obj->doneSem, exit, err = kNoResourcesErr );
+
+    // Hand ownership to the caller so the cleanup below leaves the object alone.
+    *outTest = obj;
+    obj = NULL;
+    err = kNoErr;
+
+exit:
+    if( obj ) _KeepAliveTestFree( obj );
+    return( err );
+}
+
+//===========================================================================================================================
+
+// Runs the test synchronously: schedules _KeepAliveTestStart() on the test's queue, waits for _KeepAliveTestStop() to
+// signal doneSem, and returns the test-level error that _KeepAliveTestStop() recorded.
+static OSStatus _KeepAliveTestRun( KeepAliveTestRef inTest )
+{
+    dispatch_queue_t        queue = inTest->queue;
+    dispatch_semaphore_t    sem   = inTest->doneSem;
+
+    dispatch_async_f( queue, inTest, _KeepAliveTestStart );
+    dispatch_semaphore_wait( sem, DISPATCH_TIME_FOREVER );
+
+    return( inTest->error );
+}
+
+//===========================================================================================================================
+
+// Releases the queue, semaphore, per-subtest strings, and the test object itself. All event sources, operations, the
+// connection, and sockets are expected to have been torn down already (the checks log if they were not).
+static void _KeepAliveTestFree( KeepAliveTestRef inTest )
+{
+    KeepAliveSubtest *          st;
+    KeepAliveSubtest * const    stEnd = inTest->subtests + countof( inTest->subtests );
+
+    check( !inTest->readSource );
+    check( !inTest->query );
+    check( !inTest->timer );
+    check( !inTest->keepalive );
+    check( !inTest->connection );
+    check( !inTest->connectionCtx );
+
+    dispatch_forget( &inTest->queue );
+    dispatch_forget( &inTest->doneSem );
+
+    for( st = inTest->subtests; st < stEnd; ++st )
+    {
+        // Sockets are only verified here; _KeepAliveTestStop() is the one that closes them.
+        check( !IsValidSocket( st->clientSock ) );
+        check( !IsValidSocket( st->serverSock ) );
+        ForgetMem( &st->recordName );
+        ForgetMem( &st->dataStr );
+    }
+    free( inTest );
+}
+
+//===========================================================================================================================
+
+// Queue entry point: marks the test as in progress, records its start time, and starts the first subtest. If the
+// first subtest can't be started, the whole test is stopped with that error.
+static void _KeepAliveTestStart( void *inCtx )
+{
+    const KeepAliveTestRef  test = (KeepAliveTestRef) inCtx;
+    OSStatus                err;
+
+    test->error     = kInProgressErr;
+    test->startTime = NanoTimeGetCurrent();
+
+    err = _KeepAliveTestStartSubtest( test );
+    require_noerr( err, fail );
+    return;
+
+fail:
+    _KeepAliveTestStop( test, err );
+}
+
+//===========================================================================================================================
+
+static void _KeepAliveTestStop( KeepAliveTestRef inTest, OSStatus inError )
+{
+ size_t i;
+
+ inTest->error = inError;
+ inTest->endTime = NanoTimeGetCurrent();
+ _KeepAliveTestStopSubtest( inTest );
+ for( i = 0; i < countof( inTest->subtests ); ++i )
+ {
+ KeepAliveSubtest * const subtest = &inTest->subtests[ i ];
+
+ ForgetSocket( &subtest->clientSock );
+ ForgetSocket( &subtest->serverSock );
+ }
+ dispatch_semaphore_signal( inTest->doneSem );
+}
+
+//===========================================================================================================================
+
+// Starts the subtest at inTest->subtestIdx. It opens a TCP listener on the loopback address of the subtest's address
+// family (port chosen automatically) and installs a read source that accepts the incoming connection. It then starts an
+// asynchronous connection to that listener. The rest of the subtest is driven by _KeepAliveTestTCPAcceptHandler() and
+// _KeepAliveTestConnectionHandler() on inTest->queue.
+//
+// On error, only locally owned resources (listener socket, socket context, connection context) are released here.
+// Anything already stored in inTest (e.g., readSource) is left for the caller to tear down via _KeepAliveTestStop().
+//
+// NOTE(review): assumes inTest->subtestIdx < countof( inTest->subtests ); subtest is not NULL-checked.
+static OSStatus _KeepAliveTestStartSubtest( KeepAliveTestRef inTest )
+{
+ OSStatus err;
+ KeepAliveSubtest * const subtest = _KeepAliveTestGetSubtest( inTest );
+ const KeepAliveSubtestParams * const params = &kKeepAliveSubtestParams[ inTest->subtestIdx ];
+ int port;
+ SocketRef sock = kInvalidSocketRef;
+ const uint32_t loopbackV4 = htonl( INADDR_LOOPBACK );
+ SocketContext * sockCtx = NULL;
+ KeepAliveTestConnectionContext * cnxCtx = NULL;
+ Boolean useIPv4;
+ char serverAddrStr[ 64 ];
+ char prefix[ 64 ];
+
+ subtest->error = kInProgressErr;
+ subtest->startTime = NanoTimeGetCurrent();
+ subtest->description = params->description;
+
+ require_action( ( params->family == AF_INET ) || ( params->family == AF_INET6 ), exit, err = kInternalErr );
+
+ // Create TCP listener socket.
+
+ useIPv4 = ( params->family == AF_INET ) ? true : false;
+ err = ServerSocketOpenEx( params->family, SOCK_STREAM, IPPROTO_TCP,
+ useIPv4 ? ( (const void *) &loopbackV4 ) : ( (const void *) &in6addr_loopback ), kSocketPort_Auto, &port,
+ kSocketBufferSize_DontSet, &sock );
+ require_noerr( err, exit );
+
+ // "addr:port" for IPv4, "[addr]:port" for IPv6; used both for logging and as the connection's destination.
+ if( useIPv4 ) SNPrintF( serverAddrStr, sizeof( serverAddrStr ), "%.4a:%d", &loopbackV4, port );
+ else SNPrintF( serverAddrStr, sizeof( serverAddrStr ), "[%.16a]:%d", in6addr_loopback.s6_addr, port );
+ _KeepAliveTestGetSubtestLogPrefix( inTest, prefix, sizeof( prefix ) );
+ kat_ulog( kLogLevelInfo, "%s: Will listen for connections on %s\n", prefix, serverAddrStr );
+
+ err = SocketContextCreate( sock, inTest, &sockCtx );
+ require_noerr( err, exit );
+ sock = kInvalidSocketRef; // The socket context now owns the listener socket.
+
+ // Create read source for TCP listener socket.
+
+ check( !inTest->readSource );
+ err = DispatchReadSourceCreate( sockCtx->sock, inTest->queue, _KeepAliveTestTCPAcceptHandler,
+ SocketContextCancelHandler, sockCtx, &inTest->readSource );
+ require_noerr( err, exit );
+ sockCtx = NULL; // Handed to the read source, whose cancel handler is SocketContextCancelHandler.
+ dispatch_resume( inTest->readSource );
+
+ cnxCtx = (KeepAliveTestConnectionContext *) calloc( 1, sizeof( *cnxCtx ) );
+ require_action( cnxCtx, exit, err = kNoMemoryErr );
+
+ // Start asynchronous connection to listener socket.
+
+ kat_ulog( kLogLevelInfo, "%s: Will connect to %s\n", prefix, serverAddrStr );
+
+ // NOTE(review): 5 * kNanosecondsPerSecond is presumably the connection timeout (5 seconds) — confirm against
+ // AsyncConnection_Connect()'s parameter list.
+ check( !inTest->connection );
+ err = AsyncConnection_Connect( &inTest->connection, serverAddrStr, 0, kAsyncConnectionFlags_None,
+ 5 * UINT64_C_safe( kNanosecondsPerSecond ), kSocketBufferSize_DontSet, kSocketBufferSize_DontSet,
+ NULL, NULL, _KeepAliveTestConnectionHandler, cnxCtx, inTest->queue );
+ require_noerr( err, exit );
+
+ // The back pointer is set only once the connection exists. The handler runs on inTest->queue, which is presumably
+ // the queue this function is running on, so it can't observe the context before this point.
+ cnxCtx->test = inTest;
+ check( !inTest->connectionCtx );
+ inTest->connectionCtx = cnxCtx;
+ cnxCtx = NULL; // Now freed by _KeepAliveTestConnectionHandler().
+
+exit:
+ ForgetSocket( &sock );
+ if( sockCtx ) SocketContextRelease( sockCtx );
+ FreeNullSafe( cnxCtx );
+ return( err );
+}
+
+//===========================================================================================================================
+
+// Tears down the current subtest's in-flight work: the listener's read source, the keepalive operation, the record
+// query, the query timer, and any pending asynchronous connection. Sockets stored in the subtest are not closed here.
+// This is called from both _KeepAliveTestContinue() and _KeepAliveTestStop(), so it can run when some members are
+// already NULL (the *Forget helpers are presumably NULL-safe).
+static void _KeepAliveTestStopSubtest( KeepAliveTestRef inTest )
+{
+ dispatch_source_forget( &inTest->readSource );
+ DNSServiceForget( &inTest->keepalive );
+ DNSServiceForget( &inTest->query );
+ dispatch_source_forget( &inTest->timer );
+ _KeepAliveTestForgetConnection( inTest );
+}
+
+//===========================================================================================================================
+
+// Returns the subtest at the current index, or NULL once the index has moved past the last subtest.
+static KeepAliveSubtest * _KeepAliveTestGetSubtest( KeepAliveTestRef inTest )
+{
+    const size_t idx = inTest->subtestIdx;
+
+    if( idx >= countof( inTest->subtests ) ) return( NULL );
+    return( inTest->subtests + idx );
+}
+
+//===========================================================================================================================
+
+// Formats a "Subtest <n>/<total>" log prefix (n is 1-based) into the caller's buffer and returns that buffer.
+static const char * _KeepAliveTestGetSubtestLogPrefix( KeepAliveTestRef inTest, char *inBufPtr, size_t inBufLen )
+{
+    const size_t    subtestNum   = inTest->subtestIdx + 1;
+    const size_t    subtestCount = countof( inTest->subtests );
+
+    SNPrintF( inBufPtr, inBufLen, "Subtest %zu/%zu", subtestNum, subtestCount );
+    return( inBufPtr );
+}
+
+//===========================================================================================================================
+
+// Finishes the current subtest with inSubtestError as its result and starts the next one, if any.
+// Returns an error only if the test state is inconsistent or the next subtest failed to start.
+// If outDone is non-NULL, it is set to true only when no error occurred and every subtest has finished.
+static OSStatus _KeepAliveTestContinue( KeepAliveTestRef inTest, OSStatus inSubtestError, Boolean *outDone )
+{
+    OSStatus            err;
+    KeepAliveSubtest *  finished;
+
+    require_action( inTest->subtestIdx <= countof( inTest->subtests ), exit, err = kInternalErr );
+
+    err = kNoErr;
+    if( inTest->subtestIdx == countof( inTest->subtests ) ) goto exit; // Every subtest already finished.
+
+    finished = _KeepAliveTestGetSubtest( inTest );
+    _KeepAliveTestStopSubtest( inTest );
+    finished->endTime = NanoTimeGetCurrent();
+    finished->error   = inSubtestError;
+
+    inTest->subtestIdx += 1;
+    if( inTest->subtestIdx < countof( inTest->subtests ) )
+    {
+        err = _KeepAliveTestStartSubtest( inTest );
+        require_noerr_quiet( err, exit );
+    }
+
+exit:
+    if( outDone ) *outDone = ( !err && ( inTest->subtestIdx == countof( inTest->subtests ) ) ) ? true : false;
+    return( err );
+}
+
+//===========================================================================================================================
+
+// Read-source handler for the TCP listener. It accepts the server side of the subtest's connection into
+// subtest->serverSock and then stops listening, since only one connection is expected per subtest.
+// If accept() fails, the whole test is stopped.
+static void _KeepAliveTestTCPAcceptHandler( void *inCtx )
+{
+    const SocketContext * const listenerCtx = (SocketContext *) inCtx;
+    const KeepAliveTestRef      test        = (KeepAliveTestRef) listenerCtx->userContext;
+    KeepAliveSubtest * const    subtest     = _KeepAliveTestGetSubtest( test );
+    OSStatus                    err;
+    sockaddr_ip                 peerAddr;
+    socklen_t                   peerAddrLen;
+    char                        prefix[ 64 ];
+
+    check( !IsValidSocket( subtest->serverSock ) );
+    peerAddrLen = (socklen_t) sizeof( peerAddr );
+    subtest->serverSock = accept( listenerCtx->sock, &peerAddr.sa, &peerAddrLen );
+    err = map_socket_creation_errno( subtest->serverSock );
+    require_noerr( err, fail );
+
+    _KeepAliveTestGetSubtestLogPrefix( test, prefix, sizeof( prefix ) );
+    kat_ulog( kLogLevelInfo, "%s: Accepted connection from %##a\n", prefix, &peerAddr.sa );
+
+    dispatch_source_forget( &test->readSource );
+    return;
+
+fail:
+    _KeepAliveTestStop( test, err );
+}
+
+//===========================================================================================================================
+
+// AsyncConnection completion handler. If the test still owns this connection, the handler forgets the connection and
+// passes the socket (ownership included) and the error to _KeepAliveTestHandleConnection(). If the test already forgot
+// it (the back pointer is NULL), the socket is closed. The context is always freed here.
+static void _KeepAliveTestConnectionHandler( SocketRef inSock, OSStatus inError, void *inArg )
+{
+    KeepAliveTestConnectionContext * const  cnxCtx = (KeepAliveTestConnectionContext *) inArg;
+    const KeepAliveTestRef                  test   = cnxCtx->test;
+    SocketRef                               sock   = inSock;
+
+    if( !test )
+    {
+        ForgetSocket( &sock );
+    }
+    else
+    {
+        _KeepAliveTestForgetConnection( test );
+        _KeepAliveTestHandleConnection( test, sock, inError );
+    }
+    free( cnxCtx );
+}
+
+//===========================================================================================================================
+
+// Time limit, in seconds, for the query that verifies the keepalive record.
+#define kKeepAliveTestQueryTimeoutSecs 5
+
+// Handles the client side of the subtest's TCP connection. Takes ownership of inSock.
+//
+// On success, it does the following:
+//   - Records the connection's local and remote addresses.
+//   - Registers a keepalive record for the connection with the subtest's DNSServiceSleepKeepalive* variant.
+//   - Computes the expected record name and rdata string.
+//   - Starts a LocalOnly query for that NULL-type record, plus a one-shot timer that limits how long the query may take.
+//
+// Error handling: if DNSServiceSleepKeepalive_sockaddr() can't be soft-linked, only this subtest fails and the test
+// moves on to the next one. Any other error stops the whole test.
+static void _KeepAliveTestHandleConnection( KeepAliveTestRef inTest, SocketRef inSock, OSStatus inError )
+{
+ OSStatus err;
+ KeepAliveSubtest * const subtest = _KeepAliveTestGetSubtest( inTest );
+ const KeepAliveSubtestParams * const params = &kKeepAliveSubtestParams[ inTest->subtestIdx ];
+ socklen_t len;
+ uint32_t value;
+ int family, i;
+ Boolean subtestFailed = false;
+ Boolean done;
+ char prefix[ 64 ];
+
+ require_noerr_action( inError, exit, err = inError );
+
+ check( !IsValidSocket( subtest->clientSock ) );
+ subtest->clientSock = inSock;
+ inSock = kInvalidSocketRef;
+
+ // Get local and remote IP addresses.
+
+ len = (socklen_t) sizeof( subtest->local );
+ err = getsockname( subtest->clientSock, &subtest->local.sa, &len );
+ err = map_global_noerr_errno( err );
+ require_noerr( err, exit );
+
+ len = (socklen_t) sizeof( subtest->remote );
+ err = getpeername( subtest->clientSock, &subtest->remote.sa, &len );
+ err = map_global_noerr_errno( err );
+ require_noerr( err, exit );
+
+ _KeepAliveTestGetSubtestLogPrefix( inTest, prefix, sizeof( prefix ) );
+ kat_ulog( kLogLevelInfo, "%s: Connection established: %##a <-> %##a\n",
+ prefix, &subtest->local.sa, &subtest->remote.sa );
+
+ // Call either DNSServiceSleepKeepalive() or DNSServiceSleepKeepalive_sockaddr().
+ // A random, non-zero timeout is used so that the rdata check can tell this registration apart from stale records.
+
+ check( subtest->timeoutKA == 0 );
+ subtest->timeoutKA = (unsigned int) RandomRange( 1, UINT_MAX );
+
+ switch( params->callVariant )
+ {
+ case kKeepAliveCallVariant_TakesSocket:
+ kat_ulog( kLogLevelInfo, "%s: Will call DNSServiceSleepKeepalive() for client-side socket\n", prefix );
+ check( !inTest->keepalive );
+ err = DNSServiceSleepKeepalive( &inTest->keepalive, 0, subtest->clientSock,
+ subtest->timeoutKA, _KeepAliveTestKeepaliveCallback, inTest );
+ require_noerr( err, exit );
+
+ err = DNSServiceSetDispatchQueue( inTest->keepalive, inTest->queue );
+ require_noerr( err, exit );
+ break;
+
+ case kKeepAliveCallVariant_TakesSockAddrs:
+ kat_ulog( kLogLevelInfo,
+ "%s: Will call DNSServiceSleepKeepalive_sockaddr() for local and remote sockaddrs\n", prefix );
+ // A missing symbol fails only this subtest, not the whole test (see the subtestFailed handling at exit).
+ if( !SOFT_LINK_HAS_FUNCTION( system_dnssd, DNSServiceSleepKeepalive_sockaddr ) )
+ {
+ kat_ulog( kLogLevelError,
+ "%s: Failed to soft link DNSServiceSleepKeepalive_sockaddr from libsystem_dnssd.\n", prefix );
+ subtestFailed = true;
+ err = kUnsupportedErr;
+ goto exit;
+ }
+ check( !inTest->keepalive );
+ err = soft_DNSServiceSleepKeepalive_sockaddr( &inTest->keepalive, 0, &subtest->local.sa, &subtest->remote.sa,
+ subtest->timeoutKA, _KeepAliveTestKeepaliveCallback, inTest );
+ require_noerr( err, exit );
+
+ err = DNSServiceSetDispatchQueue( inTest->keepalive, inTest->queue );
+ require_noerr( err, exit );
+ break;
+
+ default:
+ kat_ulog( kLogLevelError, "%s: Invalid KeepAliveCallVariant value %d\n", prefix, (int) params->callVariant );
+ err = kInternalErr;
+ goto exit;
+ }
+ // Use the same logic that the DNSServiceSleepKeepalive functions use to derive a record name and rdata.
+ // The record name's numeric label is the sum of the local address's bytes plus the local port as stored in the
+ // sockaddr (network byte order).
+
+ value = 0;
+ family = subtest->local.sa.sa_family;
+ if( family == AF_INET )
+ {
+ const struct sockaddr_in * const sin = &subtest->local.v4;
+ const uint8_t * ptr;
+
+ check_compile_time_code( sizeof( sin->sin_addr.s_addr ) == 4 );
+ ptr = (const uint8_t *) &sin->sin_addr.s_addr;
+ for( i = 0; i < 4; ++i ) value += ptr[ i ];
+ value += sin->sin_port; // Note: No byte swap (ntohs()). This is what DNSServiceSleepKeepalive does.
+
+ check( subtest->remote.sa.sa_family == AF_INET );
+ ASPrintF( &subtest->dataStr, "t=%u h=%.4a d=%.4a l=%u r=%u",
+ subtest->timeoutKA, &subtest->local.v4.sin_addr.s_addr, &subtest->remote.v4.sin_addr.s_addr,
+ ntohs( subtest->local.v4.sin_port ), ntohs( subtest->remote.v4.sin_port ) );
+ require_action( subtest->dataStr, exit, err = kNoMemoryErr );
+ }
+ else if( family == AF_INET6 )
+ {
+ const struct sockaddr_in6 * const sin6 = &subtest->local.v6;
+
+ check_compile_time_code( countof( sin6->sin6_addr.s6_addr ) == 16 );
+ for( i = 0; i < 16; ++i ) value += sin6->sin6_addr.s6_addr[ i ];
+ value += sin6->sin6_port; // Note: No byte swap (ntohs()). This is what DNSServiceSleepKeepalive does.
+
+ check( subtest->remote.sa.sa_family == AF_INET6 );
+ ASPrintF( &subtest->dataStr, "t=%u H=%.16a D=%.16a l=%u r=%u",
+ subtest->timeoutKA, subtest->local.v6.sin6_addr.s6_addr, subtest->remote.v6.sin6_addr.s6_addr,
+ ntohs( subtest->local.v6.sin6_port ), ntohs( subtest->remote.v6.sin6_port ) );
+ require_action( subtest->dataStr, exit, err = kNoMemoryErr );
+ }
+ else
+ {
+ kat_ulog( kLogLevelError, "%s: Unexpected local address family %d\n", prefix, family );
+ err = kInternalErr;
+ goto exit;
+ }
+
+ // Start query for the new keepalive record.
+
+ check( !subtest->recordName );
+ ASPrintF( &subtest->recordName, "%u._keepalive._dns-sd._udp.local.", value );
+ require_action( subtest->recordName, exit, err = kNoMemoryErr );
+
+ kat_ulog( kLogLevelInfo, "%s: Will query for %s NULL record\n", prefix, subtest->recordName );
+ check( !inTest->query );
+ err = DNSServiceQueryRecord( &inTest->query, 0, kDNSServiceInterfaceIndexLocalOnly, subtest->recordName,
+ kDNSServiceType_NULL, kDNSServiceClass_IN, _KeepAliveTestQueryRecordCallback, inTest );
+ require_noerr( err, exit );
+
+ err = DNSServiceSetDispatchQueue( inTest->query, inTest->queue );
+ require_noerr( err, exit );
+
+ // Start timer to enforce a time limit on the query.
+ // The leeway is 1/20 (5%) of the timeout.
+
+ check( !inTest->timer );
+ err = DispatchTimerOneShotCreate( dispatch_time_seconds( kKeepAliveTestQueryTimeoutSecs ),
+ kKeepAliveTestQueryTimeoutSecs * ( INT64_C_safe( kNanosecondsPerSecond ) / 20 ), inTest->queue,
+ _KeepAliveTestQueryTimerHandler, inTest, &inTest->timer );
+ require_noerr( err, exit );
+ dispatch_resume( inTest->timer );
+
+exit:
+ ForgetSocket( &inSock );
+ // A subtest-only failure advances to the next subtest; any other error (or finishing the last subtest) stops the test.
+ if( subtestFailed )
+ {
+ err = _KeepAliveTestContinue( inTest, err, &done );
+ check_noerr( err );
+ }
+ else
+ {
+ done = false;
+ }
+ if( err || done ) _KeepAliveTestStop( inTest, err );
+}
+
+//===========================================================================================================================
+
+// Cancels the pending asynchronous connection, if any. The connection context's back pointer is cleared first so that
+// a later invocation of _KeepAliveTestConnectionHandler() ignores the test; the handler still frees the context.
+static void _KeepAliveTestForgetConnection( KeepAliveTestRef inTest )
+{
+    if( !inTest->connection ) return;
+
+    check( inTest->connectionCtx );
+    inTest->connectionCtx->test = NULL;
+    inTest->connectionCtx       = NULL;
+    AsyncConnection_Forget( &inTest->connection );
+}
+
+//===========================================================================================================================
+
+// DNSServiceSleepKeepalive* callback. Every result is logged. Success is not acted on here, because the keepalive
+// record is verified by the record query. An error fails the current subtest and advances the test.
+static void DNSSD_API _KeepAliveTestKeepaliveCallback( DNSServiceRef inSDRef, DNSServiceErrorType inError, void *inCtx )
+{
+    const KeepAliveTestRef  test = (KeepAliveTestRef) inCtx;
+    OSStatus                err;
+    Boolean                 done;
+    char                    prefix[ 64 ];
+
+    Unused( inSDRef );
+
+    _KeepAliveTestGetSubtestLogPrefix( test, prefix, sizeof( prefix ) );
+    kat_ulog( kLogLevelInfo, "%s: Keepalive callback error: %#m\n", prefix, inError );
+
+    if( !inError ) return;
+
+    err = _KeepAliveTestContinue( test, inError, &done );
+    check_noerr( err );
+    if( err || done ) _KeepAliveTestStop( test, err );
+}
+
+//===========================================================================================================================
+
+// Fires when the keepalive-record query exceeds its time limit. It fails the current subtest with kTimeoutErr and moves
+// on, stopping the test if that was the last subtest or if advancing failed.
+static void _KeepAliveTestQueryTimerHandler( void *inCtx )
+{
+    const KeepAliveTestRef          test    = (KeepAliveTestRef) inCtx;
+    const KeepAliveSubtest * const  current = _KeepAliveTestGetSubtest( test );
+    OSStatus                        err;
+    Boolean                         finished;
+    char                            logPrefix[ 64 ];
+
+    _KeepAliveTestGetSubtestLogPrefix( test, logPrefix, sizeof( logPrefix ) );
+    kat_ulog( kLogLevelInfo, "%s: Query for \"%s\" timed out.\n", logPrefix, current->recordName );
+
+    err = _KeepAliveTestContinue( test, kTimeoutErr, &finished );
+    check_noerr( err );
+    if( err || finished ) _KeepAliveTestStop( test, err );
+}
+
+//===========================================================================================================================
+
+// DNSServiceQueryRecord() callback for the keepalive-record query.
+//
+// A result passes only if all of the following hold:
+//   - It reports no error and has the Add flag set.
+//   - It is for subtest->recordName, with type NULL and class IN.
+//   - Its rdata is at least 1 + strlen( dataStr ) + 1 bytes long.
+//   - The first rdata byte (a length byte) is at least strlen( dataStr ) + 1.
+//   - The bytes after the length byte begin with dataStr.
+// Every result, valid or not, ends the current subtest (with kNoErr or the failure reason) and advances the test.
+// The whole test is stopped once the last subtest is done or if advancing fails, and that error is propagated.
+static void DNSSD_API
+    _KeepAliveTestQueryRecordCallback(
+        DNSServiceRef           inSDRef,
+        DNSServiceFlags         inFlags,
+        uint32_t                inInterfaceIndex,
+        DNSServiceErrorType     inError,
+        const char *            inFullName,
+        uint16_t                inType,
+        uint16_t                inClass,
+        uint16_t                inRDataLen,
+        const void *            inRDataPtr,
+        uint32_t                inTTL,
+        void *                  inCtx )
+{
+    OSStatus                    err;
+    OSStatus                    subtestErr;
+    const KeepAliveTestRef      test    = (KeepAliveTestRef) inCtx;
+    KeepAliveSubtest * const    subtest = _KeepAliveTestGetSubtest( test );
+    const uint8_t *             ptr;
+    size_t                      dataStrLen, minLen;
+    Boolean                     done;
+    char                        prefix[ 64 ];
+
+    Unused( inSDRef );
+    Unused( inInterfaceIndex );
+    Unused( inTTL );
+
+    _KeepAliveTestGetSubtestLogPrefix( test, prefix, sizeof( prefix ) );
+
+    // Check the error first. For error results, the other arguments aren't meaningful, and the dns_sd client stub can
+    // pass a NULL inFullName (e.g., when the connection to the daemon is lost), which must not reach strcasecmp().
+    if( inError )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: Got unexpected error %#m.\n",
+            prefix, subtest->recordName, inError );
+        subtestErr = inError;
+        goto exit;
+    }
+    if( !inFullName || ( strcasecmp( inFullName, subtest->recordName ) != 0 ) )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: Got unexpected record name \"%s\".\n",
+            prefix, subtest->recordName, inFullName ? inFullName : "<NULL>" );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    if( inType != kDNSServiceType_NULL )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: Got unexpected record type %d (%s) != %d (NULL).\n",
+            prefix, subtest->recordName, inType, RecordTypeToString( inType ), kDNSServiceType_NULL );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    if( inClass != kDNSServiceClass_IN )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: Got unexpected record class %d != %d (IN).\n",
+            prefix, subtest->recordName, inClass, kDNSServiceClass_IN );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    if( ( inFlags & kDNSServiceFlagsAdd ) == 0 )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: Missing Add flag.\n", prefix, subtest->recordName );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    kat_ulog( kLogLevelInfo, "%s: QueryRecord(%s) result rdata: %#H\n",
+        prefix, subtest->recordName, inRDataPtr, inRDataLen, inRDataLen );
+
+    dataStrLen = strlen( subtest->dataStr ) + 1;   // There's a NUL terminator at the end of the rdata.
+    minLen = 1 + dataStrLen;                        // The first byte of the rdata is a length byte.
+    if( inRDataLen < minLen )
+    {
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: rdata length (%d) is less than expected minimum (%zu).\n",
+            prefix, subtest->recordName, inRDataLen, minLen );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    ptr = (const uint8_t *) inRDataPtr;
+    if( ptr[ 0 ] < dataStrLen )
+    {
+        kat_ulog( kLogLevelError,
+            "%s: QueryRecord(%s) result: rdata length byte value (%d) is less than expected minimum (%zu).\n",
+            prefix, subtest->recordName, ptr[ 0 ], dataStrLen );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    if( memcmp( &ptr[ 1 ], subtest->dataStr, dataStrLen - 1 ) != 0 )
+    {
+        // An rdata mismatch fails the subtest. Logging alone would let a wrong keepalive record pass.
+        kat_ulog( kLogLevelError, "%s: QueryRecord(%s) result: rdata body doesn't contain '%s'.\n",
+            prefix, subtest->recordName, subtest->dataStr );
+        subtestErr = kUnexpectedErr;
+        goto exit;
+    }
+    subtestErr = kNoErr;
+
+exit:
+    // subtestErr is this subtest's result; err is whether the test could advance. A failure to advance must reach
+    // _KeepAliveTestStop() so that it is reported as a test-level error.
+    err = _KeepAliveTestContinue( test, subtestErr, &done );
+    check_noerr( err );
+    if( err || done ) _KeepAliveTestStop( test, err );
+}
+#endif // TARGET_OS_DARWIN
+