
#include <errno.h>
#include <libgen.h>
#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <mjpegtools/yuv4mpeg.h>

// a frame is an array of plane pointers, one byte buffer per plane
typedef uint8_t ** yuv_frame;

// stream stuff
// input/output file descriptors; init() sets them to stdin/stdout
static int fd_in, fd_out;
// stream headers: istream is read from fd_in, ostream is a copy written to fd_out
static y4m_stream_info_t istream, ostream;
// per-frame header, reused for every frame that is read (and written back)
static y4m_frame_info_t iframe;
static yuv_frame yuv_buffer = NULL;		// plane buffers allocated in init(); e.g. for 4:2:0: [0]:Y (width*height), [1]:Cb, [2]:Cr (width*height/4, each)
static int n_planes;					// number of planes, depends on input format (4:2:0, 4:4:4alpha, etc.)
// number of frames read from the input so far (incremented in proc_frames())
static int fnum = 0;
// framerate of the input stream; used to convert seconds into frame counts
static y4m_ratio_t framerate;
static log_level_t loglevel = 1;		// normal; may be changed with the -v option

int proc_frames(int num, bool skip) {
	int ret = Y4M_OK;
	if (num < 0) {
		mjpeg_info("%s the rest of the frames", (skip) ? "skipping" : "copying ");
	} else {
		mjpeg_info("%s %d frames", (skip) ? "skipping" : "copying ", num);
	}
	while (num-- && (ret = y4m_read_frame(fd_in, &istream, &iframe, yuv_buffer)) == Y4M_OK) {
		fnum++;
		if (loglevel > 1)
			mjpeg_debug("%s frame %d [temporal: %d, spatial: %d]",
					(skip) ? "skipping" : "processing", fnum,
					y4m_fi_get_temporal(&iframe), y4m_fi_get_spatial(&iframe));
		if (!skip) {
			y4m_write_frame(fd_out, &ostream, &iframe, yuv_buffer);
		}
	}
	return ret;
}

// basename of argv[0]; set at the top of main(), used as log identifier and in usage()
static char *program_name;

int init() {
	int res = Y4M_OK;
	fd_in = STDIN_FILENO, fd_out = STDOUT_FILENO;
	
	y4m_init_stream_info(&istream);
	y4m_init_frame_info(&iframe);
	y4m_init_stream_info(&ostream);
	
	res = y4m_read_stream_header(fd_in, &istream);
	if (res != Y4M_OK) {
		mjpeg_error("error reading stream header: %s", y4m_strerr(res));
		return res;
	}
	
	y4m_copy_stream_info(&ostream, &istream);
	res = y4m_write_stream_header(fd_out, &ostream);
	if (res != Y4M_OK) {
		mjpeg_error("error writing stream header: %s", y4m_strerr(res));
		return res;
	}
	
	n_planes = y4m_si_get_plane_count(&istream);
	yuv_buffer = malloc(sizeof(uint8_t *) * n_planes);
	int i;
	for (i=0; i<n_planes; i++) {
		yuv_buffer[i] = malloc(y4m_si_get_plane_length(&istream, i));
		if (yuv_buffer[i] == NULL) {
			perror("malloc()");
			while (i-- > 0)	{
				free(yuv_buffer[i]);
				yuv_buffer[i] = NULL;
			}
			return Y4M_ERR_SYSTEM;
		}
	}
	
 	framerate = y4m_si_get_framerate(&istream);
	return res;
}

// A position in the stream, given as a frame count plus a number of seconds.
// The seconds part is converted to frames (using the stream's framerate)
// before the ranges are processed.
typedef struct {
	int frames;
	int seconds;
} offset_t;

// One half-open slice [from, to) of the stream, to be either copied or skipped.
struct range {
	struct range *next;		// singly-linked list of all ranges
	offset_t from, to;
	bool skip : 1;			// true: drop the frames, false: copy them
};

// Set by compare_ranges() to a range that is strictly nested inside another
// one; stays NULL as long as no such pair has been compared.
static const struct range * invalid_range = NULL;

/*
 * qsort() comparator for an array of `struct range *`.
 *
 * Orders by start frame, then by end frame, then copy-ranges before
 * skip-ranges. As a side effect, when one range starts after and ends before
 * the other one, the inner range is recorded in `invalid_range`.
 *
 * Returns -1, 0 or 1. Ranges that are equal in all three keys compare as 0,
 * so the ordering is consistent regardless of argument order.
 */
int compare_ranges(const void *p, const void *q) {
	const struct range *rp = *(const struct range *const *)p;
	const struct range *rq = *(const struct range *const *)q;
	if (rp->from.frames < rq->from.frames) {
		if (rp->to.frames > rq->to.frames)
			invalid_range = rq;
		return -1;
	} else if (rp->from.frames > rq->from.frames) {
		if (rp->to.frames < rq->to.frames)
			invalid_range = rp;
		return 1;
	}
	if (rp->to.frames != rq->to.frames)
		return (rp->to.frames < rq->to.frames) ? -1 : 1;
	// copy (skip == 0) sorts before skip (skip == 1); equal flags tie
	return (int)rp->skip - (int)rq->skip;
}

#define PRINT_BUF_SIZE 40
static char print_buf[PRINT_BUF_SIZE];

/*
 * Formats a range as "[<from>, <to>)" for log messages.
 *
 * Each bound is shown as "second N" when its seconds part is non-zero and as
 * "frame N" otherwise. The text is written to a single static buffer (output
 * longer than the buffer is cut off), so the returned pointer is only valid
 * until the next call.
 */
char * print_range(const struct range *r) {
	const bool from_in_seconds = (r->from.seconds != 0);
	const bool to_in_seconds = (r->to.seconds != 0);

	const char *from_unit = from_in_seconds ? "second" : "frame";
	const char *to_unit = to_in_seconds ? "second" : "frame";
	const int from_value = from_in_seconds ? r->from.seconds : r->from.frames;
	const int to_value = to_in_seconds ? r->to.seconds : r->to.frames;

	snprintf(print_buf, sizeof print_buf, "[%s %d, %s %d)",
			from_unit, from_value, to_unit, to_value);

	return print_buf;
}

// getopt() state used by the option loop in main()
// NOTE(review): POSIX <unistd.h> (already included) declares these as well,
// so the explicit externs look redundant - confirm before removing
extern char *optarg;
extern int optind, opterr, optopt;

// prints the usage line to stderr; with `true` also the full option help
void usage(bool);
// parses a "[[HH:]MM:]SS[.FF]" timestamp into *o; returns 0 on success and a
// negative value on malformed input
int parse_time(char *val, offset_t *o);

/*
 * Parses a complete, non-negative decimal frame count into *out.
 * Returns 0 on success, -1 on empty input, trailing garbage, a negative
 * value or a value that does not fit into an int.
 */
static int parse_frame_count(const char *s, int *out) {
	char *end;
	long v;

	errno = 0;
	v = strtol(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE || v < 0 || v > INT_MAX)
		return -1;
	*out = (int)v;
	return 0;
}

/*
 * Folds the seconds part of *o into its frame count using the stream's
 * framerate. `tag` only labels the debug message. Returns 0 on success, -1
 * when the framerate is unknown (denominator 0) or the result overflows.
 */
static int seconds_to_frames(const char *tag, offset_t *o) {
	int64_t sec2fram = 0;

	if (o->seconds != 0) {
		if (framerate.d == 0) {
			mjpeg_error("cannot convert seconds to frames: stream framerate is %d:%d",
					framerate.n, framerate.d);
			return -1;
		}
		// 64-bit intermediate: seconds * numerator easily exceeds 32 bits
		// (e.g. 30000/1001 fps and a day's worth of seconds)
		sec2fram = ((int64_t)o->seconds * framerate.n) / framerate.d;
	}
	if (sec2fram < 0 || sec2fram > (int64_t)INT_MAX - o->frames) {
		mjpeg_error("frame offset out of range (second %d, frame %d)",
				o->seconds, o->frames);
		return -1;
	}
	mjpeg_debug("[%s] seconds (%d) to frames: %d+%d, using framerate %d:%d",
			tag, o->seconds, (int)sec2fram, o->frames, framerate.n, framerate.d);
	o->frames += (int)sec2fram;
	return 0;
}

/*
 * Entry point: turns the -n/-s/-N/-S options into a list of ranges, reads the
 * stream header, converts all ranges to frame numbers, sorts and validates
 * them and then copies or skips the frames of each range in order.
 *
 * Exit status: EXIT_SUCCESS when all ranges were processed or the input
 * simply ended early; EXIT_FAILURE on invalid options or ranges, allocation
 * failures and stream read/write errors.
 */
int main(int argc, char **argv) {

	program_name = basename(*argv);
	int res = EXIT_SUCCESS, i;
	int pres = Y4M_OK;			// result of the most recent proc_frames() call

	struct range *ranges = NULL;
	struct range **r_arr = NULL;
	int n_ranges = 0;

	mjpeg_default_handler_identifier(program_name);
	mjpeg_default_handler_verbosity(loglevel);

	// process options into ranges
	int opt;
	bool absolute = false;
	offset_t olast = { 0, 0 };
	struct range *r = NULL;
	bool stop_at_end = false;
	while ((opt = getopt(argc, argv, "an:s:N:S:xhv:")) != -1) {
		switch (opt) {
			case 'a':
				absolute = true;
				break;

			case 'n':
			case 's':
			case 'N':
			case 'S': {
				offset_t o = { 0, 0 };
				int perr;
				if (opt == 'n' || opt == 's') {
					perr = parse_frame_count(optarg, &o.frames);
				} else {
					perr = parse_time(optarg, &o);
				}
				if (o.frames < 0 || o.seconds < 0 || perr < 0) {
					mjpeg_error("invalid argument: %s (%d)", optarg, perr);
					res = EXIT_FAILURE;
					goto done;
				}
				if (!absolute) {
					// relative offsets continue where the previous range ended
					if (o.frames > INT_MAX - olast.frames ||
							o.seconds > INT_MAX - olast.seconds) {
						mjpeg_error("offset out of range: %s", optarg);
						res = EXIT_FAILURE;
						goto done;
					}
					o.frames += olast.frames;
					o.seconds += olast.seconds;
				}
				struct range *nr = calloc(1, sizeof *nr);
				if (nr == NULL) {
					perror("malloc()");
					res = EXIT_FAILURE;
					goto done;
				}
				nr->skip = (opt == 's' || opt == 'S');
				nr->from = olast;
				nr->to = o;
				olast = o;
				if (ranges == NULL) {
					ranges = nr;
				} else {
					r->next = nr;
				}
				r = nr;
				n_ranges++;
				absolute = false;
			} break;

			case 'x':
				stop_at_end = true;
				break;

			case 'h':
				usage(true);
				res = EXIT_SUCCESS;
				goto done;

			case 'v': {
				loglevel = atoi(optarg);
				mjpeg_default_handler_verbosity(loglevel);
			} break;

			default:
				usage(false);
				res = EXIT_FAILURE;
				goto done;
		}
	}

	if (ranges == NULL) {
		mjpeg_error("no ranges given");
		usage(false);
		res = EXIT_FAILURE;
		goto done;
	}

	// init() reports its own errors
	if (init() != Y4M_OK) {
		res = EXIT_FAILURE;
		goto done;
	}

	r_arr = malloc(sizeof *r_arr * (size_t)n_ranges);
	if (r_arr == NULL) {
		perror("malloc()");
		res = EXIT_FAILURE;
		goto done;
	}
	i = 0;
	for (r = ranges; r != NULL; r = r->next) {
		// convert all ranges from seconds to frames
		if (seconds_to_frames("from", &r->from) < 0 ||
				seconds_to_frames(" to ", &r->to) < 0) {
			res = EXIT_FAILURE;
			goto done;
		}
		// a backward range would become a negative frame count, which
		// proc_frames() interprets as "all remaining frames"
		if (r->to.frames < r->from.frames) {
			mjpeg_error("range %s ends before it starts", print_range(r));
			res = EXIT_FAILURE;
			goto done;
		}
		r_arr[i++] = r;
	}
	qsort(r_arr, n_ranges, sizeof(struct range *), compare_ranges);
	if (invalid_range != NULL) {
		mjpeg_error("error processing range %s", print_range(invalid_range));
		res = EXIT_FAILURE;
		goto done;
	}

	// relink the list in sorted order; all offsets are pure frame counts now
	for (i=0; i<n_ranges; i++) {
		r_arr[i]->next = (i + 1 < n_ranges) ? r_arr[i + 1] : NULL;
		r_arr[i]->from.seconds = 0;
		r_arr[i]->to.seconds = 0;
	}
	// keep the list head in sync so the cleanup below frees every node
	ranges = r_arr[0];

	fnum = 0;
	for (i=0; i<n_ranges && pres == Y4M_OK; i++) {
		r = r_arr[i];
		// frames in front of the range get the opposite treatment
		if (fnum < r->from.frames)
			pres = proc_frames(r->from.frames - fnum, !r->skip);
		if (pres == Y4M_OK)
			pres = proc_frames(r->to.frames - r->from.frames, r->skip);
	}
	if (pres == Y4M_OK) {
		if (stop_at_end) {
			mjpeg_info("all ranges processed, exiting...");
		} else {
			// the rest of the stream gets the opposite treatment of the last range
			pres = proc_frames(-1, !r->skip);
		}
	}
	// running out of input is a normal way to finish, anything else is not
	if (pres != Y4M_OK && pres != Y4M_ERR_EOF) {
		mjpeg_error("error processing frames: %s", y4m_strerr(pres));
		res = EXIT_FAILURE;
	}

done:
	free(r_arr);
	r = ranges;
	while (r != NULL) {
		struct range *nr = r->next;
		free(r);
		r = nr;
	}

	if (yuv_buffer != NULL) {
		for (i=0; i<n_planes; i++)
			free(yuv_buffer[i]);
		free(yuv_buffer);
		yuv_buffer = NULL;
	}

	y4m_fini_frame_info(&iframe);
	y4m_fini_stream_info(&istream);
	y4m_fini_stream_info(&ostream);

	return res;
}

/*
 * Parses a timestamp of the form [[HH:]MM:]SS[.FF].
 *
 * Up to three ':'-separated digit groups are accumulated into o->seconds
 * (each ':' multiplies the running total by 60); the value is added to
 * whatever o->seconds already holds, so callers normally pass a zeroed
 * offset. An optional ".FF" suffix is stored in o->frames.
 *
 * Returns 0 on success, or a negative code:
 *   -1  empty string
 *   -2  more than three digit groups
 *   -3  a group is negative or has no digits
 *   -4  unexpected character after a group
 *   -5  '.' without a frame number
 *   -6  frame number negative or followed by garbage
 *   -7  a value does not fit into an int
 * On failure *o may have been partially updated.
 */
int parse_time(char *s, offset_t *o) {
	if (*s == '\0')
		return -1;

	long res;
	char *endptr;
	int loop = 0;

	for (;;) {
		if (++loop > 3)
			return -2;
		errno = 0;
		res = strtol(s, &endptr, 10);
		if (res < 0 || endptr == s)
			return -3;
		if (*endptr != '\0' && *endptr != '.' && *endptr != ':')
			return -4;
		if (errno == ERANGE || (long long)o->seconds + res > INT_MAX)
			return -7;
		o->seconds += (int)res;
		if (*endptr != ':')
			break;
		// another group follows: promote what we have to the next unit
		if (o->seconds > INT_MAX / 60)
			return -7;
		o->seconds *= 60;
		s = endptr + 1;
	}

	if (*endptr == '.') {
		s = endptr + 1;
		if (*s == '\0')
			return -5;
		errno = 0;
		res = strtol(s, &endptr, 10);
		if (res < 0 || *endptr != '\0')
			return -6;
		if (errno == ERANGE || res > INT_MAX)
			return -7;
		o->frames = (int)res;
	}
	return 0;
}

/*
 * Writes the help text to stderr.
 *
 * all: false prints only the one-line synopsis; true additionally prints a
 *      short description in front of it and the option reference, the TIME
 *      format and the license note after it.
 */
void usage(bool all) {
	// everything that follows the synopsis in the long form, in output order
	static const char *const details[] = {
		"\n",
		"Options:\n",
		"  -a         next argument uses absolute addressing\n",
		"  -n FRAMES  output the next FRAMES frames\n",
		"  -s FRAMES  skip the next FRAMES frames\n",
		"  -N TIME    output the next frames up to TIME\n",
		"  -S TIME    skip the next frames up to TIME\n",
		"  -x         stop execution immediatly after last range\n",
		"  -h         shows this help message\n",
		"  -v N       set log-level (0: quiet, 1: normal, 2:verbose/debug)\n",
		"\n",
		"TIME is a timestamp given in the following form: [[HH:]MM:]SS[.FF]; HH for\n"
		"hours, MM for minutes, SS for seconds and FF for frames.\n",
		"\n",
		"This program is Free Software and released under the terms of the General Public\n"
		"License v2 by Franz Brausse (<dev@karlchenofhell.org>) in the hope to be useful.\n",
	};
	const size_t n_details = sizeof details / sizeof details[0];
	size_t k;

	if (all) {
		fprintf(stderr, "%s copies slices of frames in the yuv4mpeg-stream\n", program_name);
	}

	fprintf(stderr,
			"usage: ... |  %s [-a] { { -n | -s } FRAMES | { -N | -S } TIME } [...]  | ...\n",
			program_name);

	// TODO: from end
	if (!all) {
		return;
	}

	for (k = 0; k < n_details; k++) {
		fputs(details[k], stderr);
	}
}
