/*
 * Public snippet #3, created by Georg GH. Hopp.
 *
 * socketpair2
 *
 * socketpair2.c
 */
#define _POSIX_C_SOURCE 199309L

#include <string.h>
#include <stdio.h>
#include <unistd.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <fcntl.h>
#include <errno.h>

#include <sys/epoll.h>

/*
 * Report a fatal error on stderr and terminate with EXIT_FAILURE.
 *
 * message: human-readable description of what failed.
 *
 * If errno is non-zero the line printed is "<strerror(errno)>: <message>",
 * otherwise only "<message>" is printed. This function never returns.
 */
void
errorExit(const char * const message)
{
	const int savedErrno = errno;

	if (savedErrno == 0) {
		fprintf(stderr, "%s\n", message);
	} else {
		const char * const reason = strerror(savedErrno);
		fprintf(stderr, "%s: %s\n", reason, message);
	}

	exit(EXIT_FAILURE);
}

/*
 * Child side of the demo.
 *
 * Puts `socket` (the child's end of the socketpair) into non-blocking mode,
 * waits for it to become readable with epoll, and receives two file
 * descriptors sent by the parent as SCM_RIGHTS ancillary data. From each
 * received descriptor the first line (at most sizeof buffer - 1 bytes) is
 * read and written to stdout, after which the descriptor is closed.
 *
 * Every failure terminates the process through errorExit().
 */
void
childCode(int socket)
{
	enum { MAX_EVENTS = 1, DESCRIPTORS_EXPECTED = 2 };

	struct epoll_event ev;
	struct epoll_event events[MAX_EVENTS];
	char               buffer[1024];
	int                epollfd;
	int                flags;
	int                num;

	flags = fcntl(socket, F_GETFL, 0);
	if (flags == -1 || fcntl(socket, F_SETFL, flags | O_NONBLOCK) == -1) {
		errorExit("failed to make socket non-blocking");
	}

	epollfd = epoll_create1(0);
	if (epollfd == -1) {
		errorExit("failed to create epoll fd");
	}

	memset(&ev, 0, sizeof(ev));
	ev.events  = EPOLLIN;
	ev.data.fd = socket;
	if (epoll_ctl(epollfd, EPOLL_CTL_ADD, socket, &ev) == -1) {
		errorExit("failed to add socket to epoll poll");
	}

	for (num = 0; num < DESCRIPTORS_EXPECTED; num++) {
		struct msghdr    msg;
		struct cmsghdr * cmsg;
		FILE           * handle;
		int              fd;
		int              nfds;
		// A plain char array is not guaranteed to be aligned for
		// struct cmsghdr; the union provides that alignment.
		union {
			char           buf[CMSG_SPACE(sizeof(int))];
			struct cmsghdr align;
		} control;

		// Retry when a signal interrupts the wait instead of aborting.
		do {
			nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
		} while (nfds == -1 && errno == EINTR);

		if (nfds == -1) {
			errorExit("failed in epoll_wait");
		}

		memset(&msg, 0, sizeof(msg));
		msg.msg_control    = control.buf;
		msg.msg_controllen = sizeof(control.buf);

		if (0 > recvmsg(events[0].data.fd, &msg, 0)) {
			errorExit("error receiving fd");
		}

		// The validation below does not set errno; clear it so errorExit()
		// does not print a stale error left by an earlier successful call.
		errno = 0;

		cmsg = CMSG_FIRSTHDR(&msg);
		if (cmsg == NULL
				|| (msg.msg_flags & MSG_CTRUNC)
				|| cmsg->cmsg_level != SOL_SOCKET
				|| cmsg->cmsg_type  != SCM_RIGHTS
				|| cmsg->cmsg_len   <  CMSG_LEN(sizeof(int))) {
			errorExit(
					"The first control structure contains "
					"no file descriptor.");
		}

		memcpy(&fd, CMSG_DATA(cmsg), sizeof(fd));

		handle = fdopen(fd, "r");
		if (NULL == handle) {
			errorExit("child fails to open retrieved handle");
		}

		errno = 0; // fgets() leaves errno untouched on plain end-of-file
		if (NULL == fgets(buffer, sizeof(buffer), handle)) {
			errorExit("child fails to read from retrieved handle");
		}

		// fclose() releases the FILE object and closes fd along with it.
		fclose(handle);

		printf("%s", buffer);
	}

	close(epollfd);
}

/*
 * Send `fd` to the peer of `socket` as SCM_RIGHTS ancillary data attached to
 * a datagram that carries no payload bytes. The caller keeps ownership of
 * `fd` and is responsible for closing it. Terminates the process via
 * errorExit() if sendmsg() fails.
 */
static void
sendDescriptor(int socket, int fd)
{
	struct msghdr    msg;
	struct cmsghdr * cmsg;
	// A plain char array is not guaranteed to be aligned for
	// struct cmsghdr; the union provides that alignment.
	union {
		char           buf[CMSG_SPACE(sizeof(int))];
		struct cmsghdr align;
	} control;

	memset(&msg, 0, sizeof(msg));
	memset(&control, 0, sizeof(control)); // don't transmit uninitialised padding
	msg.msg_control    = control.buf;
	msg.msg_controllen = sizeof(control.buf); // necessary for CMSG_FIRSTHDR to
	                                          // return the correct value
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type  = SCM_RIGHTS;
	cmsg->cmsg_len   = CMSG_LEN(sizeof(int));
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));
	msg.msg_controllen = cmsg->cmsg_len; // total size of all control blocks

	if (0 > sendmsg(socket, &msg, 0)) {
		errorExit("error sending descriptor");
	}
}

/*
 * Parent side of the demo: open the file "dummy" twice and pass each
 * descriptor to the child over `socket`. The child's receive loop runs
 * exactly twice, so sending two descriptors lets it finish and exit.
 * The parent's own copy of each descriptor is closed right after sending.
 */
void
parentCode(int socket)
{
	for (int i = 0; i < 2; i++) {
		int fd = open("dummy", O_RDONLY);

		if (0 > fd) {
			errorExit("error opening testfile: dummy");
		}

		sendDescriptor(socket, fd);
		close(fd);
	}
}

/*
 * Create an AF_UNIX datagram socketpair and fork. The child runs childCode()
 * on its end, and the parent runs parentCode() on its end, then waits for the
 * child.
 *
 * Returns EXIT_SUCCESS only if the child exited normally with EXIT_SUCCESS,
 * so a failure in the child is reflected in the parent's exit status.
 */
int
main(int argc, char * argv[])
{
	int   sv[2]; // not named "socket" to avoid shadowing socket(2)
	pid_t pid;
	int   status;

	(void)argc;
	(void)argv;

	if (0 > socketpair(AF_UNIX, SOCK_DGRAM, 0, sv)) {
		errorExit("error creating socketpair");
	}

	switch (pid = fork()) {
		case -1:
			errorExit("error in fork");
			break; // not reached: errorExit() terminates the process

		case 0:
			close(sv[0]);
			childCode(sv[1]);
			close(sv[1]);
			// returning from main flushes the child's buffered stdout
			return EXIT_SUCCESS;

		default:
			close(sv[1]);
			parentCode(sv[0]);
			close(sv[0]);

			while (-1 == waitpid(pid, &status, 0)) {
				if (errno != EINTR) {
					errorExit("error waiting for child");
				}
			}

			if (!WIFEXITED(status) || WEXITSTATUS(status) != EXIT_SUCCESS) {
				return EXIT_FAILURE;
			}
			break;
	}

	return EXIT_SUCCESS;
}

// vim: set ts=4 sw=4:
/* Please register or login to post a comment. */