/*
   mergemem.c by Marnix Coppens

   This file is part of mergemem by Philipp Richter & Philipp Reisner

   mergemem is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   mergemem is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with mergemem; see the file COPYING.  If not, write to
   the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.
 */

#include "mergemem.h"
#include "merge_utils.h"
#include "../mergemod/mergemod.h"

#include <dirent.h>
#include <fcntl.h>
#include <limits.h>
#include <signal.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <asm/page.h>
#include <sys/ioctl.h>
#include <sys/resource.h>
#include <linux/kdev_t.h>

/*-----------------------------------------------------------------------------------------------*/

/*
 * PTRFMT: printf format for a user-space address, zero-padded to the width
 * of an unsigned long.
 *
 * ULONG_MAX must be used for the word-size test.  Inside #if every integer is
 * evaluated as intmax_t/uintmax_t, so the former test ((~0UL) == 0xffffffff)
 * was false even on 32-bit hosts and always selected the 16-digit format.
 *
 * STACK_TOP: read_pid_map() treats an anonymous mapping ending exactly at
 * this address as the stack.
 */
#if ULONG_MAX == 0xffffffffUL
#define PTRFMT "%08lx"
#define	STACK_TOP	0xC0000000UL
#else
#define PTRFMT "%016lx"
#define	STACK_TOP	0xC0000000UL	/* FIXME: value for 64 bits machines? */
#endif

/*
 * State of one page's cached checksum (MapList.chksumvalid[]): whether
 * MapList.chksum[] for that page has not been computed yet, holds a valid
 * value, or could not be obtained from the module.
 */
enum chksum_valid
{
    CHKSUM_UNKNOWN = 0,		/* not computed yet */
    CHKSUM_SUCCESS = 1,		/* chksum[] entry is valid */
    CHKSUM_FAILED = 2		/* checksum request failed for this page */
};

/*
 * Selection lists (consumed by walk_proc_dir() and read_pid_map()):
 *   cmd_list - process names to merge; each entry carries its own
 *              counter/interval timer, advanced by run_daemon()
 *   ign_list - process names that are never merged
 *   lib_list - when non-empty, only anonymous mappings attributed to one of
 *              these dev:ino pairs are merged
 *   pid_list - explicitly requested pids; entries not found during a /proc
 *              scan are removed by walk_proc_dir()
 *   mrg_list - processes selected for the current merge pass
 */
CmdList    *cmd_list;		/* built once when parsing the command line */
CmdList    *ign_list;		/* built once when parsing the command line */
LibList    *lib_list;		/* built once when parsing the command line */
PidList    *pid_list;		/* built once when parsing the command line */
MrgList    *mrg_list;		/* rebuilt for every merge */

/*
 * Runtime settings.  The defaults are given here; presumably some are
 * overridden by parse_options() (not shown) — TODO confirm.
 *   flog          - log stream; main() opens it in append mode, unbuffered
 *   loglevel      - verbosity threshold consulted by logmsg() (not shown);
 *                   main() sets it to 0 if a daemon cannot open its log
 *   interval      - daemon loop iterations (one sleep(1) each) between full
 *                   merge passes
 *   nicelevel     - priority applied with setpriority() by the daemon
 *   mergemod_fd   - descriptor of /dev/mergemem, used for every module ioctl
 *   mergeall      - set by main() when no arguments are given: consider
 *                   every process on a full pass
 *   run_as_daemon - fork and merge periodically instead of merging once
 *   logfile       - path of the log file, LOGFILE by default
 */
FILE       *flog;
int         loglevel = 1;
int         interval = 60;
int         nicelevel = 18;
int         mergemod_fd = -1;
char        mergeall = FALSE;
char        run_as_daemon = FALSE;
char        logfile[PATH_MAX] = LOGFILE;

/* Short names for the request structures passed to the mergemod ioctls. */
typedef struct mergemem_mmem merge_mem;
typedef struct mergemem_chksum merge_chk;

/* Outcome counters for one pair of processes; zeroed by merge_pids_mem(). */
typedef struct merge_stat
{
    int         n_shared;	/* not updated by the code in this file */
    int         n_already;	/* pages the module reported as already shared */
    int         n_error;	/* pages whose checksum or merge request failed */
    int         n_merged;	/* pages newly merged */
    int         n_notmerged;	/* pages whose checksums differ */
    int         n_chksum1;	/* checksums computed for the first pid */
    int         n_chksum2;	/* checksums computed for the second pid */
} merge_stat;

/*===============================================================================================*/

static void
merge_anon_maps(MapList *mapl1, MapList *mapl2, pid_t pid1, pid_t pid2, merge_stat * pms)
{
    int         idx;
    ulong       start1, start2;
    merge_mem   mm_mem;
    merge_chk   mm_chk;

    mm_mem.pid1 = pid1;
    mm_mem.pid2 = pid2;
    start1 = mapl1->vmstart;
    start2 = mapl2->vmstart;
    idx = (mapl1->vmend - start1) / PAGE_SIZE;	/* already checked: sizes are equal for both */

    logmsg(8, "merge_anon_maps: pid(%d, %d), addr(" PTRFMT ", " PTRFMT "), size %dp\n",
	   pid1, pid2, start1, start2, idx);

    while (--idx >= 0)
    {
	/* we now cache the checksum and its validity for every page */
	if (mapl1->chksumvalid[idx] == CHKSUM_FAILED)
	{
	    pms->n_error++;
	    continue;
	}
	else if (mapl1->chksumvalid[idx] == CHKSUM_UNKNOWN)
	{
	    pms->n_chksum1++;
	    mm_chk.pid = pid1;
	    mm_chk.addr = start1 + idx * PAGE_SIZE;
	    if (ioctl(mergemod_fd, MERGEMEM_GEN_CHECKSUM, &mm_chk) != MERGEMEM_SUCCESS)
	    {
		mapl1->chksumvalid[idx] = CHKSUM_FAILED;
		pms->n_error++;
		continue;
	    }
	    else
	    {
		mapl1->chksumvalid[idx] = CHKSUM_SUCCESS;
		mapl1->chksum[idx] = mm_chk.chksum;
	    }
	}

	/* if the page is valid for pid1, try do to the same for pid2 */
	if (mapl2->chksumvalid[idx] == CHKSUM_FAILED)
	{
	    pms->n_error++;
	    continue;
	}
	else if (mapl2->chksumvalid[idx] == CHKSUM_UNKNOWN)
	{
	    pms->n_chksum2++;
	    mm_chk.pid = pid2;
	    mm_chk.addr = start2 + idx * PAGE_SIZE;
	    if (ioctl(mergemod_fd, MERGEMEM_GEN_CHECKSUM, &mm_chk) != MERGEMEM_SUCCESS)
	    {
		mapl2->chksumvalid[idx] = CHKSUM_FAILED;
		pms->n_error++;
		continue;
	    }
	    else
	    {
		mapl2->chksumvalid[idx] = CHKSUM_SUCCESS;
		mapl2->chksum[idx] = mm_chk.chksum;
	    }
	}

	if (mapl1->chksum[idx] != mapl2->chksum[idx])
	{
	    pms->n_notmerged++;
	    continue;
	}

#ifdef DONT_DO_THE_ACTUAL_MERGE
	pms->n_already++;
	continue;
#endif
	/* both pages are present and their checksums are equal -> try to merge them */
	mm_mem.addr1 = start1 + idx * PAGE_SIZE;
	mm_mem.addr2 = start2 + idx * PAGE_SIZE;
	switch (ioctl(mergemod_fd, MERGEMEM_MERGE_MEM, &mm_mem))
	{
	    case MERGEMEM_SUCCESS:
		pms->n_merged++;
		break;

	    case MERGEMEM_ALREADYSH:
		pms->n_already++;
		break;

	    default:
		pms->n_error++;
		break;
	}
    }
}

/*-----------------------------------------------------------------------------------------------*/

/*
 * Merge the first process of `mrgl1` with every process that follows it in
 * the list.  For each pair, every mapping of the first process is matched
 * with find_map_in_maplist() against the other process's maps.  Each match
 * is passed to merge_anon_maps(), and a per-pair summary is logged when at
 * least one match was tried.
 * Returns the total number of pages newly merged over all pairs.
 */
static int
merge_pids_mem(MrgList *mrgl1)
{
    int         total = 0;
    MrgList    *other = mrgl1->next;

    while (other)
    {
	merge_stat  stats;
	char        any_pair = FALSE;
	MapList    *mine;

	memset(&stats, 0, sizeof(stats));

	/* TODO: find some way to avoid this double iteration */
	for (mine = mrgl1->map; mine; mine = mine->next)
	{
	    MapList    *theirs = find_map_in_maplist(other->map, mine);

	    if (!theirs)
		continue;
	    merge_anon_maps(mine, theirs, mrgl1->pid, other->pid, &stats);
	    any_pair = TRUE;
	}

	total += stats.n_merged;
	if (any_pair)
	    logmsg(4, "--> (%5d,%5d) %3dp merged, %3dp shared, %3dp misses, %3dp not merged"
		   ", %3d + %3d checksums\n",
		   mrgl1->pid, other->pid, stats.n_merged, stats.n_already, stats.n_error,
		   stats.n_notmerged, stats.n_chksum1, stats.n_chksum2);

	other = other->next;
    }

    return total;
}

/*-----------------------------------------------------------------------------------------------*/

#define MMAPLEN		1000

/*
 * Read /proc/<pid>/maps and return a newly allocated list of the anonymous
 * mappings of `pid` that are merge candidates.  Returns 0 if the maps file
 * cannot be opened or no candidate is found.  The caller owns the list.
 *
 * An anonymous mapping is a candidate when it either:
 *   - directly follows a file mapping, and is tagged with that file's dev:ino, or
 *   - ends at STACK_TOP (the stack), and is tagged 0:0.
 * When lib_list is set, only candidates whose dev:ino is found in it are kept.
 * Each entry gets per-page chksum[] and chksumvalid[] arrays.  chksumvalid[]
 * starts zeroed, i.e. CHKSUM_UNKNOWN ("not computed yet").
 */
static MapList *
read_pid_map(pid_t pid)
{
    char        line[MMAPLEN];
    char        mapsfile[32];
    char        vmperm[16];
    ulong       vmstart, vmend, vmoff;
    unsigned int dev_maj, dev_min;
    int         nchars, npages;
    dev_t       dev;
    ino_t       ino, ino1;
    MapList    *mapl, *map_list;
    FILE       *fp;

    snprintf(mapsfile, sizeof(mapsfile), "/proc/%d/maps", (int) pid);
    if ((fp = fopen(mapsfile, "r")) == 0)
	return 0;

    ino = dev = 0;		/* dev:ino of the previous line's file mapping, 0 if none */
    map_list = 0;
    while (fgets(line, sizeof(line), fp))
    {
	/*
	 * Only the leading fields are parsed, so an over-long line (long path
	 * name) is just truncated.  Discard its remainder so it is not parsed
	 * as a separate, malformed line, which would end the scan.
	 */
	if (!strchr(line, '\n'))
	{
	    int         c;

	    while ((c = getc(fp)) != EOF && c != '\n')
		;
	}

	/* the kernel prints the device as hexadecimal "major:minor" */
	if (sscanf(line, "%lx-%lx %15s %lx %x:%x %n", &vmstart, &vmend, vmperm, &vmoff,
		   &dev_maj, &dev_min, &nchars) != 6)
	    break;

	if (line[nchars] == '/')
	{
	    /*
	     * FIXME: This kernel is 2.1.xx or higher. Instead of the inode number
	     * the full name is shown in the map. We must resolve this again to find
	     * the dev:ino numbers. Pretty lame if you realize what the kernel just did..
	     *
	     * Perhaps this should be delegated to the mergemod.o module.
	     *
	     * This mapping is not understood, so forget the previous file mapping:
	     * a following anonymous mapping must not be attributed to it.
	     */
	    ino = 0;
	    continue;
	}

	ino1 = (ino_t) strtoul(line + nchars, NULL, 10);
	if (ino1)
	{
	    /* file mapping: skip it, but remember its dev (ino is updated below) */
	    dev = MKDEV(dev_maj, dev_min);
	}
	else if (ino != 0 || vmend == STACK_TOP)
	{
	    /* anonymous mapping: either a stack page or right after file mapping */
	    if (ino == 0)
		dev = 0;

	    if (!lib_list || find_lib_by_devino(lib_list, dev, ino))
	    {
		logmsg(8, "[%d] Found anon map for %lx:%lu, at " PTRFMT " - " PTRFMT "\n",
		       pid, (unsigned long) dev, (unsigned long) ino, vmstart, vmend);
		npages = (vmend - vmstart) / PAGE_SIZE;
		mapl = safe_malloc(sizeof(*mapl));
		mapl->vmstart = vmstart;
		mapl->vmend = vmend;
		mapl->dev = dev;
		mapl->ino = ino;
		mapl->chksum = safe_malloc(npages * sizeof(*mapl->chksum));
		mapl->chksumvalid = safe_malloc(npages * sizeof(*mapl->chksumvalid));
		/* 0 == CHKSUM_UNKNOWN: merge_anon_maps() computes checksums lazily */
		memset(mapl->chksumvalid, 0, npages * sizeof(*mapl->chksumvalid));

		mapl->next = map_list;
		map_list = mapl;
	    }
	}
	/* track the previous mapping for every parsed line, including rejected ones */
	ino = ino1;
    }
    fclose(fp);

    return map_list;
}

/*-----------------------------------------------------------------------------------------------*/

/*
 * Build mrg_list, the processes to merge in this pass, by scanning /proc.
 * Our own pid and any process on ign_list are skipped.  A process is selected
 * when its comm name can be read from /proc/<pid>/stat, read_pid_map() finds
 * at least one usable anonymous mapping, and one of these holds:
 *   - mergeall is set and cmdonly is false;
 *   - cmdonly is false and the pid is on pid_list;
 *   - its name is on cmd_list and that entry's timer has expired (counter 0).
 * Afterwards, pid_list entries that were not found during the scan are
 * removed.  cmd_list is never pruned, because processes with those names
 * may appear later.
 */

static void
walk_proc_dir(int cmdonly)
{
    char        statfile[32];
    char        commname[64];
    char        nameok, cmdok;
    pid_t       pid, self;
    PidList    *pidl, **ppidl;
    MrgList    *mrgl;
    MapList    *mapl;
    CmdList    *cmdl;
    DIR        *dir;
    FILE       *fp;
    struct dirent *de;

    if ((dir = opendir("/proc")) == NULL)
	err_exit(ERR_NOPROC, "Could not open /proc directory.\n");

    self = getpid();
    while ((de = readdir(dir)))
    {
	/* non-numeric entries ("self", "meminfo", ...) give 0 and are skipped */
	if ((pid = atoi(de->d_name)) == 0 || pid == self)
	    continue;

	snprintf(statfile, sizeof(statfile), "/proc/%d/stat", (int) pid);
	if ((fp = fopen(statfile, "r")) == 0)
	    continue;

	/* field width 63 keeps the comm name within commname[64] */
	nameok = (fscanf(fp, "%*d (%63[^)]", commname) == 1);
	fclose(fp);
	if (!nameok)
	    continue;

	/* for cleaning purposes... pidl is also used a few lines further! */
	pidl = find_cmd_by_pid(pid_list, pid);
	if (pidl)
	    pidl->valid = 1;

	/* skip this process if it's on the ignore list */
	if (find_cmd_by_name(ign_list, commname))
	    continue;

	cmdok = mergeall && !cmdonly;
	if (!cmdok)
	{
	    /* processes specified by name must have an expired timer */
	    if ((!cmdonly && pidl) ||
		((cmdl = find_cmd_by_name(cmd_list, commname)) && !cmdl->counter))
		cmdok = TRUE;
	}

	/* nearly there: add it to the list only if it has a usable map */
	if (cmdok && (mapl = read_pid_map(pid)))
	{
	    mrgl = safe_malloc(sizeof(*mrgl));
	    mrgl->pid = pid;
	    mrgl->map = mapl;

	    mrgl->next = mrg_list;
	    mrg_list = mrgl;
	}
	else
	    logmsg(2, "skipping pid %d\n", pid);
    }
    closedir(dir);

    /* remove non-existent processes from pid_list */
    ppidl = &pid_list;
    for (pidl = pid_list; pidl; pidl = *ppidl)
    {
	if (!pidl->valid)
	{
	    logmsg(2, "removing pid %d\n", pidl->pid);
	    *ppidl = pidl->next;
	    free(pidl);
	}
	else
	{
	    pidl->valid = 0;	/* for the next time we get called */
	    ppidl = &pidl->next;
	}
    }
}

/*===============================================================================================*/

/*
 * Run one merge pass.  Build mrg_list via walk_proc_dir(cmdonly), merge the
 * head of the list with every later entry and free it, until the list is
 * empty.  This visits each pair once.  Finally log the pages saved.
 */
static void
merge_once(int cmdonly)
{
    int         saved_pages;
    MrgList    *mrgl;
    MapList    *mapl, *nmapl;

    /* we need to build the entire merge list first, so we can merge pairwise */
    walk_proc_dir(cmdonly);

    /* merge and free at the same time :^) */
    saved_pages = 0;
    while (mrg_list)
    {
	saved_pages += merge_pids_mem(mrg_list);
	for (mapl = mrg_list->map; mapl; mapl = nmapl)
	{
	    nmapl = mapl->next;
	    free(mapl->chksumvalid);
	    free(mapl->chksum);
	    free(mapl);
	}
	mrgl = mrg_list->next;
	free(mrg_list);
	mrg_list = mrgl;
    }

    /*
     * PAGE_SIZE may be an unsigned long, in which case the product is too.
     * Convert explicitly so the argument matches %lu and cannot overflow an int.
     */
    logmsg(1, "Saved pages: %d (%lu KB)\n", saved_pages,
	   (unsigned long) saved_pages * (unsigned long) PAGE_SIZE / 1024UL);
}

/*-----------------------------------------------------------------------------------------------*/

/*
 * Main loop of the daemon child; never returns (exit happens through the
 * SIGHUP/SIGTERM handler normal_exit()).
 *
 * It first detaches from the terminal, installs the signal handlers and
 * lowers the priority to `nicelevel`.  Then it wakes about once per second:
 *   - a full merge pass runs every `interval` iterations;
 *   - a name-only pass runs whenever a cmd_list entry's own timer expires.
 */
static void
run_daemon(void)
{
    int         counter;
    int         devnull;
    char        domerge;
    CmdList    *cmdl;

    /*
     * Point stdin/stdout/stderr at /dev/null rather than leaving them closed.
     * Closed, the next open()/fopen() would reuse descriptors 0-2, and any
     * later write to stdout/stderr would land in whatever file got them.
     */
    devnull = open("/dev/null", O_RDWR);
    if (devnull >= 0)
    {
	dup2(devnull, 0);
	dup2(devnull, 1);
	dup2(devnull, 2);
	if (devnull > 2)
	    close(devnull);
    }
    else
    {
	close(0);
	close(1);
	close(2);
    }
    setsid();
    signal(SIGHUP, normal_exit);
    signal(SIGTERM, normal_exit);
    setpriority(PRIO_PROCESS, 0, nicelevel);

    logmsg(1, "Daemon started, pid %d\n", getpid());

    counter = 0;
    while (1)
    {
	domerge = FALSE;
	if (++counter == interval)
	{
	    counter = 0;
	    domerge = TRUE;
	}
	for (cmdl = cmd_list; cmdl; cmdl = cmdl->next)
	{
	    if (++cmdl->counter == cmdl->interval)
	    {
		cmdl->counter = 0;
		domerge = TRUE;
	    }
	}
	/* counter is 0 exactly when the global timer fired: then do a full pass */
	if (domerge)
	    merge_once(counter);
	sleep(1);
    }
}

/*===============================================================================================*/

/*
 * Entry point.  With no arguments every process is considered; otherwise
 * parse_options() configures the lists and settings.  main() then:
 *   - opens /dev/mergemem and the log file;
 *   - checks that the module version matches MOD_VERSION;
 *   - either merges once and exits, or forks a daemon (run_daemon()) while
 *     the parent returns 0.
 */
int
main(int argc, char *argv[])
{
    int         retval, modver = 0;

    if (argc == 1)
	mergeall = TRUE;
    else
	parse_options(argc, argv);

    mergemod_fd = open("/dev/mergemem", O_RDONLY);
    if (mergemod_fd == -1)
	err_exit(ERR_MODULE, "Could not open /dev/mergemem\n");

    /* prepare for logging messages */
    if (loglevel > 0 && (flog = fopen(logfile, "a")) == 0)
    {
	fprintf(stderr, "Could not open logfile for appending.\n");
	if (run_as_daemon)
	    loglevel = 0;
    }
    if (flog)
	setvbuf(flog, 0, _IONBF, 0);

    /* check the version */
    retval = ioctl(mergemod_fd, MERGEMEM_CHECK_VER, &modver);
    if (retval)
    {
	/* the ioctl failed, so modver holds no reported version: don't log it as one */
	logmsg(1, "could not query mergemod module version\n");
	err_exit(ERR_VERSION, "module incompatible with program\n");
    }
    if (modver != MOD_VERSION)
    {
	logmsg(1, "module %d incompatible with program (expected %d)\n", modver, MOD_VERSION);
	err_exit(ERR_VERSION, "module incompatible with program\n");
    }
    logmsg(2, "mergemod module has version %d\n", modver);
    logmsg(1, "mergemem started\n");

    if (run_as_daemon)
    {
	switch (fork())
	{
	    case -1:
		err_exit(ERR_FORK, "Could not fork daemon.\n");
		break;

	    case 0:
		/* child: run_daemon() loops forever */
		run_daemon();
		break;
	}
	/* parent: fall through and exit, leaving the daemon running */
    }
    else
    {
	merge_once(FALSE);
	normal_exit(0);
    }

    return 0;
}
