/*
  mergemod.c
  Kernel module for 2.[02].x Kernels
  
  This file is part of mergemem by Philipp Richter & Philipp Reisner

  mergemem is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.
  
  mergemem is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with mergemem; see the file COPYING.  If not, write to
  the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.

*/
/* #define RED_HAT_KERNEL 1 */


/* Refuse to build for anything but a 2.x kernel.  K_VERSION and
 * K_PATCHLEVEL are not defined in this file -- NOTE(review): presumably
 * supplied by the build system (e.g. -D flags). */
#if(K_VERSION!=2)
#error "Only 2.x.x Kernels supported"
#endif

/* READ_COUNT_TYPE is the count parameter type of the 2.0-style proc read
 * handler declared below; Red Hat Alpha kernels are treated as using
 * unsigned long there, everything else int.  READ_TYPE is not referenced
 * in the visible code. */
#if defined(K_ARCH_alpha) && defined(RED_HAT_KERNEL)
#define READ_TYPE long
#define READ_COUNT_TYPE unsigned long
#else
#define READ_TYPE int
#define READ_COUNT_TYPE int
#endif

/* MERGEMOD_20 selects the Linux 2.0 code paths throughout this file; any
 * other 2.x patchlevel takes the newer code paths. */
#if(K_PATCHLEVEL==0)
#define MERGEMOD_20
#endif

#ifdef HAVE_AUTOCONF
#include <linux/autoconf.h> /* for CONFIG_MODVERSIONS */
#endif

#include <asm/smp.h>
#ifndef MERGEMOD_20
#include <asm/current.h>
#endif

#include <linux/config.h>
#include <linux/module.h>
#ifdef CONFIG_MODVERSIONS
#include <linux/modversions.h>
#endif
#include <linux/kernel.h>
#include <linux/fs.h>
#include <linux/proc_fs.h>
#include <linux/sched.h>
#include <linux/swap.h>
#include <linux/miscdevice.h>

#include <asm/page.h>
#ifdef MERGEMOD_20
#include <asm/segment.h>
#else
#include <asm/uaccess.h>
#endif
#include <asm/io.h>
#include <asm/pgtable.h>

#include "mergemod.h"

/*
 * Snapshot of one user page as seen by get_pageinfo(): the address the
 * caller asked about, the page address found by the page-table walk, and
 * the mem_map reference count sampled at lookup time.
 */
typedef struct
{
  /* user-space (virtual) address supplied by the caller */
  const void *v_addr;

  /* page address returned by get_phys_addr() for v_addr */
  const void *p_addr;

  /* mem_map reference count of that page when it was looked up */
  size_t count;

  /* hash slot -- not written by any code visible in this file */
  unsigned long hash;
} MPAGEINFO;


#ifdef MERGEMOD_20
/*
 * Linux 2.0 lacks several helpers the newer code paths use; these minimal
 * stand-ins let the rest of the file be written against the newer names.
 */

/* Walk the circular task list, starting at current_set[0], and return the
 * task whose pid matches; NULL once the walk is back at its start. */
static inline struct task_struct * find_task_by_pid(int pid)
{
  struct task_struct * const first = current_set[0]; /* current instead?? */
  struct task_struct * p = first;

  for (;;)
  {
    if (p->pid == pid)
      return p;
    p = p->next_task;
    if (p == first)
      break;
  }
  return NULL;
}

/* User-copy helpers built from verify_area() + memcpy_{to,from}fs()
 * (adapted from autofs_i.h).  They return 0 on success and -1 when
 * verify_area() rejects the range; callers here only test for non-zero. */
static inline int copy_to_user(void *dst, void *src, unsigned long len)
{
  if (verify_area(VERIFY_WRITE, dst, len) != 0)
    return -1;
  memcpy_tofs(dst, src, len);
  return 0;
}

static inline int copy_from_user(void *dst, void *src, unsigned long len)
{
  if (verify_area(VERIFY_READ, src, len) != 0)
    return -1;
  memcpy_fromfs(dst, src, len);
  return 0;
}

/* On this kernel reference counts are read directly. */
#define atomic_read(x) (*(x))
#endif

/*
 * Permission check applied to every target task.
 *
 * Returns 1 if the caller's real or effective uid equals p's real or
 * saved uid, or if suser() grants superuser rights; 0 otherwise.
 *
 * suser() is consulted last, only when no ordinary uid match exists:
 * NOTE(review): on these kernels suser() also records that superuser
 * privileges were used (PF_SUPERPRIV), so calling it first flagged root
 * callers even when their uid alone would have sufficed.
 */
static inline int same_uid(struct task_struct* p)
{
  if (current->euid == p->suid || current->euid == p->uid ||
      current->uid  == p->suid || current->uid  == p->uid)
    return 1;

  return suser() ? 1 : 0;
}

/*
 * Return non-zero when a page address obtained from get_phys_addr() is
 * unusable: either no page was found (0) or it lies at or above
 * high_memory.  high_memory may be declared as a pointer (NOTE(review):
 * as on 2.2 kernels), so it is cast explicitly instead of relying on a
 * pointer/integer comparison.
 */
static inline int nopage(unsigned long page)
{
  return page == 0 || page >= (unsigned long) high_memory;
}

/*
 * Walk tsk's page tables for user address addr.
 *
 * Returns the page address taken from the PTE (as char *) and stores a
 * pointer to that PTE in *pe.  Returns NULL, leaving *pe untouched, if a
 * directory level is empty or the page is not present.  A corrupt pgd or
 * pmd entry is logged and cleared before NULL is returned.
 */
static char * get_phys_addr(struct task_struct *tsk, unsigned long addr, pte_t **pe)
{
  pgd_t *dir;
  pmd_t *middle;
  pte_t *slot;
  pte_t entry;

  dir = pgd_offset(tsk->mm, addr);
  if (pgd_none(*dir))
    return NULL;
  if (pgd_bad(*dir)) {
    printk("Bad page dir entry %08lx\n", pgd_val(*dir));
    pgd_clear(dir);
    return NULL;
  }

  middle = pmd_offset(dir, addr);
  if (pmd_none(*middle))
    return NULL;
  if (pmd_bad(*middle)) {
    printk("Bad page middle entry %08lx\n", pmd_val(*middle));
    pmd_clear(middle);
    return NULL;
  }

  /* Locate the PTE slot once; test a snapshot of its value and hand the
   * slot itself back to the caller. */
  slot = pte_offset(middle, addr);
  entry = *slot;
  if (!pte_present(entry))
    return NULL;

  *pe = slot;
  return (char *) pte_page(entry);
}

static inline int noanonymous_mapping( struct vm_area_struct *vma)
{
#ifdef MERGEMOD_20
  return (vma->vm_flags & (VM_LOCKED|VM_SHM)) || vma->vm_inode;
#else
  return (vma->vm_flags & (VM_IO|VM_LOCKED|VM_SHM)) || vma->vm_file;
#endif

}


/* The release file operation returns void on 2.0 and int on later
 * kernels; close_type/close_return let the release handlers below be
 * written once for both. */
#ifdef MERGEMOD_20
#define close_type void
#define close_return return
#else
#define close_type int
#define close_return return 0
#endif

/* Non-2.0 kernels only: an exported, optional hook (NULL = none) that
 * mergemem() calls with the retained page's mem_map entry when the
 * merged-away page had no other users, plus module metadata. */
#ifdef MERGEMOD_20
#else
void (*mergemem_profile_hook1)(mem_map_t * ) = NULL;
EXPORT_SYMBOL(mergemem_profile_hook1);
MODULE_AUTHOR("Philipp Reisner <e9525415@stud2.tuwien.ac.at>");
MODULE_DESCRIPTION("mergemod - merge memory of running processes");
#endif

/* statistics for proc filesys: reported by /proc/mergemem, cleared by the
 * MERGEMEM_RESET_STAT ioctl */
static unsigned long stat_merged = 0L;     /* merged; second page had no other user */
static unsigned long stat_alreadysh = 0L;  /* both addresses already mapped one page */
static unsigned long stat_notequ = 0L;     /* page contents differed */
static unsigned long stat_moremap = 0L;    /* merged; second page still had other users */

/* ioctl argument structures (layouts come from mergemod.h) */
static struct mergemem_chksum mm_chksum;
static struct mergemem_mmem mm_mmem;
static struct mergemem_get_phys_addr mm_get_phys_addr;
static struct mergemem_get_page mm_get_page;

/* proc file ops */
static int mergemem_proc_open(struct inode*, struct file*);
static close_type mergemem_proc_close(struct inode*,struct file*);
#ifdef MERGEMOD_20
static ssize_t mergemem_proc_read(struct inode*, struct file*, char*, READ_COUNT_TYPE);
#else
static ssize_t mergemem_proc_read(struct file*, char*, size_t, loff_t* );
#endif

/* device ops */
static int mergemem_dev_open(struct inode*, struct file*);
static close_type mergemem_dev_close(struct inode*,struct file*);
static int mergemem_dev_ioctl(struct inode*, struct file*, unsigned int cmd, unsigned long arg);

/* the important routines... (the workers behind the ioctl requests) */
static int mergemem(pid_t, const void *, pid_t, const void * );
static int genchksum(pid_t, const void *addr, unsigned long*, int* );
static int getphysaddr( pid_t, const void *, const void * * );
static int provide_page(pid_t,const void * v_addr, void* );

/* page hash functions */
static unsigned long m_hash_addrothalf(const void*, const size_t);
static unsigned long m_hash_const(const void*, const size_t);

/* File operations for /dev/mergemem.  Positional initializer: the entries
 * follow the field order of struct file_operations for the kernel being
 * built (the flush slot only exists on non-2.0 kernels).  Only open,
 * release and ioctl are implemented. */
static struct file_operations mergemem_dev_fops =
{
  NULL,                   /* lseek - default */
  NULL,                   /* read */
  NULL,                   /* write  */
  NULL,                   /* readdir */
  NULL,                   /* select - default */
  mergemem_dev_ioctl,     /* ioctl - handles all device requests */
  NULL,                   /* mmap */
  mergemem_dev_open,      /* open code */
#ifdef MERGEMOD_20
#else
  NULL,                   /* flush */
#endif
  mergemem_dev_close,     /* release code */
  NULL                    /* can't fsync */
};

/* File operations for /proc/mergemem: a read-only text report.  Open and
 * release only maintain the module use count.  Positional initializer in
 * struct file_operations order (flush slot only on non-2.0 kernels). */
static struct file_operations mergemem_proc_fops =
{
  NULL,                   /* lseek - default */
  mergemem_proc_read,     /* read */
  NULL,                   /* write  */
  NULL,                   /* readdir */
  NULL,                   /* select - default */
  NULL,                   /* ioctl - default */
  NULL,                   /* mmap */
  mergemem_proc_open,     /* open code */
#ifdef MERGEMOD_20
#else
  NULL,                   /* flush */
#endif
  mergemem_proc_close,    /* release code */
  NULL                    /* can't fsync */
};

/* Inode operations for /proc/mergemem: only the default file operations
 * are supplied; every other inode operation is left NULL. */
static struct inode_operations mergemem_proc_iops =
{
  &mergemem_proc_fops,    /* my own file-ops */
  NULL,                   /* create */
  NULL,                   /* lookup */
  NULL,                   /* link */
  NULL,                   /* unlink */
  NULL,                   /* symlink */
  NULL,                   /* mkdir */
  NULL,                   /* rmdir */
  NULL,                   /* mknod */
  NULL,                   /* rename */
  NULL,                   /* readlink */
  NULL,                   /* follow_link */
  NULL,                   /* readpage */
  NULL,                   /* writepage */
  NULL,                   /* bmap */
  NULL,                   /* truncate */
  NULL                    /* permission */
};

/* The /proc/mergemem entry, registered under proc_root by init_module().
 * Positional initializer for struct proc_dir_entry:
 *   low_ino 0 -- NOTE(review): presumably assigned by proc_register();
 *                cleanup_module() passes this field to proc_unregister()
 *   8         -- name length, strlen("mergemem")
 *   mode      -- regular file readable by everyone, 1 link; the two zeros
 *                after the link count are presumably uid/gid (verify)
 *   inode ops -- mergemem_proc_iops; remaining pointers unused (NULL) */
struct proc_dir_entry mergemem_proc_dir =
{
  0, 8, "mergemem",
  S_IFREG | S_IRUGO , 1, 0, 0,
  0, &mergemem_proc_iops,
  NULL, NULL,
  NULL,
  NULL, NULL
};

/* Misc character device "mergemem" with minor number MERGEMEM_MINOR (from
 * mergemod.h); all requests go through mergemem_dev_fops. */
static struct miscdevice mergemem_dev =
{
  MERGEMEM_MINOR, "mergemem", &mergemem_dev_fops
};

/* proc file operations */
/* open() for /proc/mergemem: bump the module use count; always succeeds. */
static int mergemem_proc_open(struct inode *inode, struct file *file)
{
  MOD_INC_USE_COUNT;
#ifdef DEBUG
  printk(MODULE_NAME"/proc/mergemem opened.\n");
#endif
  return 0;
}

/* release() for /proc/mergemem: drop the use count taken at open. */
static close_type mergemem_proc_close(struct inode *inode, struct file *file)
{
#ifdef DEBUG
  printk(MODULE_NAME"/proc/mergemem closed.\n");
#endif
  MOD_DEC_USE_COUNT;
  close_return;
}

/*
 * read() handler for /proc/mergemem.
 *
 * Formats the module version and the statistics counters into a text
 * report, then copies the part starting at file->f_pos (at most count
 * bytes) to the user buffer and advances f_pos.
 * Returns the number of bytes copied, 0 at end of file (or for a negative
 * position), or -EFAULT if the user buffer cannot be written.
 */
#ifdef MERGEMOD_20
static ssize_t mergemem_proc_read(struct inode *inode, struct file *file, char *buf, READ_COUNT_TYPE count)
#else
static ssize_t mergemem_proc_read(struct file *file, char *buf, size_t count, loff_t *unused)
#endif
{
  /*
   * sprintf() is unbounded, so the buffer must hold the worst case.  With
   * 64-bit longs that is 16 + 11 + 1 bytes for the version line plus
   * 4 * (16 + 20 + 8 + 20 + 4) bytes for the counter lines = 300 bytes,
   * plus the terminating NUL.
   */
  char buffer[320];
  size_t len;
  size_t pos;
  size_t avail;

#ifdef DEBUG
  printk(MODULE_NAME"reading...\n");
#endif

  sprintf(buffer,
          "version       : %d\n"
	  "merged        : %lu pages, %lu kb\n"
	  "not equal     : %lu pages, %lu kb\n"
	  "already shared: %lu pages, %lu kb\n"
	  "mappings > 1  : %lu pages, %lu kb\n", MOD_VERSION,
	  stat_merged, stat_merged * PAGE_SIZE / 1024,
	  stat_notequ, stat_notequ * PAGE_SIZE / 1024,
	  stat_alreadysh, stat_alreadysh * PAGE_SIZE / 1024,
	  stat_moremap, stat_moremap * PAGE_SIZE / 1024 );
  len = strlen(buffer);

  /* Compare in the (signed) type of f_pos so that neither a negative nor
   * a very large position is truncated into the report's range. */
  if (file->f_pos < 0 || file->f_pos >= (long) len)
    return 0;
  pos = (size_t) file->f_pos;

  avail = len - pos;
  if (avail < count)
    count = avail;

  if (copy_to_user(buf, buffer + pos, count))
    return -EFAULT;
  file->f_pos += count;

  return count;
}

/* device ops */
/* open() for /dev/mergemem: bump the module use count; always succeeds. */
static int mergemem_dev_open(struct inode *inode, struct file *file)
{
  MOD_INC_USE_COUNT;
#ifdef DEBUG
  printk(MODULE_NAME"/dev/mergemem opened.\n");
#endif
  return 0;
}

/* release() for /dev/mergemem: drop the use count taken at open. */
static close_type mergemem_dev_close(struct inode *inode, struct file *file)
{
#ifdef DEBUG
  printk(MODULE_NAME"/dev/mergemem closed.\n");
#endif
  MOD_DEC_USE_COUNT;
  close_return;
}

static int mergemem_dev_ioctl(struct inode *inode, struct file *file, unsigned int cmd, unsigned long arg)
{
  int retval;

  switch(cmd)
  {
    /* version checking */
    case MERGEMEM_CHECK_VER:
      {
	int version = MOD_VERSION;
	if(copy_to_user((int *)arg,&version,sizeof(int)))
	  return -EFAULT;
	return 0;
      }
      /* generate a checksum for a page */
    case MERGEMEM_GEN_CHECKSUM:
      if(copy_from_user(&mm_chksum, (struct mergemem_chksum*)arg,
			sizeof(struct mergemem_chksum))) return -EFAULT;
      retval = genchksum( mm_chksum.pid, (const void*) mm_chksum.addr, &(mm_chksum.chksum),
			  &(mm_chksum.nrefs));
      if(copy_to_user((struct mergemem_chksum*)arg, &mm_chksum,
		      sizeof(struct mergemem_chksum))) return -EFAULT;
      return retval;

    case MERGEMEM_MERGE_MEM:
      if(copy_from_user(&mm_mmem, (struct mergemem_mmem*)arg,
			sizeof(struct mergemem_mmem))) return -EFAULT;
      retval = mergemem( mm_mmem.pid1, (const void *) mm_mmem.addr1,
			 mm_mmem.pid2, (const void *) mm_mmem.addr2 );
      return retval;

    case MERGEMEM_GET_PHYS_ADDR:
      if(copy_from_user(&mm_get_phys_addr, (struct mergemem_get_phys_addr*)arg,
			sizeof(struct mergemem_get_phys_addr))) return -EFAULT;
      retval = getphysaddr( mm_get_phys_addr.pid, (const void*) mm_get_phys_addr.addr,
			    (const void **) &(mm_get_phys_addr.phys_addr));
      if(copy_to_user((struct mergemem_get_phys_addr*)arg, &mm_get_phys_addr,
		      sizeof(struct mergemem_get_phys_addr))) return -EFAULT;
      return retval;

    case MERGEMEM_GET_PAGE:
      if(copy_from_user(&mm_get_page, (struct mergemem_get_page*)arg,
			sizeof(struct mergemem_get_page))) return -EFAULT;
      retval = provide_page(mm_get_page.pid,(const void*) mm_get_page.addr,mm_get_page.page);
      return retval;

    case MERGEMEM_RESET_STAT:
      stat_merged = 0L;
      stat_alreadysh = 0L;
      stat_notequ = 0L;
      stat_moremap = 0L;
      return 0;

    default:
      return -EINVAL;
  }
  return 0;
}


/* Internal helper defined below.  Declared static so it has internal
 * linkage and stays out of the module's global symbol namespace; the
 * later definition without a storage class inherits this linkage. */
static int get_pageinfo(pid_t, MPAGEINFO *ip, struct vm_area_struct * *, pte_t ** );


/****************************************************************************/
/* From here on, we have the implementation of the page fusion function     */

/*
 * Merge the page mapped at addr2 in task pid2 into the page mapped at
 * addr1 in task pid1, provided both are anonymous, unlocked and have
 * identical contents.
 *
 * On success both PTEs point to page 1, write-protected, and the
 * reference page 2 held for pte2 is released.  Returns MERGEMEM_SUCCESS,
 * or MERGEMEM_MOVED if page 2 still had other users.  Otherwise:
 *  - get_pageinfo() failures for page 1 are returned unchanged;
 *  - NOTASK1/NOPAGE1 for page 2 are reported as NOTASK2/NOPAGE2, other
 *    page-2 failures unchanged;
 *  - ALREADYSH, NOANO, PLOCKED or NOTEQUAL explain why nothing was merged.
 */
static int mergemem(pid_t pid1, const void * addr1, pid_t pid2, const void * addr2)
{
  pte_t * pte1, * pte2, pte;
  mem_map_t * mape1, * mape2;
  struct vm_area_struct * vma1, * vma2;
  MPAGEINFO i1, i2;
  int retval1=MERGEMEM_SUCCESS;
  int retval;

  i1.v_addr = addr1;
  i2.v_addr = addr2;

  /* Any status other than success means i1, vma1 and pte1 were not
   * (fully) filled in, so stop before touching them. */
  retval = get_pageinfo( pid1, &i1, &vma1, &pte1 );
  if (retval != MERGEMEM_SUCCESS)
    return retval;

  /* get_pageinfo() reports the "1" variants; translate them so the caller
   * can tell which of the two pages was the problem. */
  retval = get_pageinfo( pid2, &i2, &vma2, &pte2 );
  switch(retval)
    {
    case MERGEMEM_NOTASK1:
      return MERGEMEM_NOTASK2;
    case MERGEMEM_NOPAGE1:
      return MERGEMEM_NOPAGE2;
    case MERGEMEM_SUCCESS:
      break;
    default:
      return retval;
    }

  if(i1.p_addr == i2.p_addr)
    {
      stat_alreadysh++;
      return MERGEMEM_ALREADYSH;
    }

  /* MAP_NR() is handed an integer so it works whether or not the
   * kernel's macro casts its argument itself. */
  mape1 = &mem_map[MAP_NR((unsigned long) i1.p_addr)];
  mape2 = &mem_map[MAP_NR((unsigned long) i2.p_addr)];

  /* A page with an inode attached is not an anonymous page. */
  if(mape1->inode != NULL ||
     mape2->inode != NULL )
    return MERGEMEM_NOANO;

  if(PageLocked(mape1) || PageLocked(mape2))
    return MERGEMEM_PLOCKED;

  cli(); /* start of critical section */
  if(memcmp(i1.p_addr,i2.p_addr,PAGE_SIZE))
    {
      sti(); /* possible end of critical section */
      stat_notequ++;
      return MERGEMEM_NOTEQUAL;
    }

  /* Do the actual work... */

  /*  new pte's are readonly and point to page1 */
  pte = pte_wrprotect(*pte1);

  /*  increase the count in the according mem_map structure */
  atomic_inc(&mape1->count);

  /* page1 must go off the swap_cache, since it is now mapped more
   * than one time
   */
#ifdef MERGEMOD_20
  if(delete_from_swap_cache(MAP_NR((unsigned long) i1.p_addr)))
    pte = pte_mkdirty(pte);
#else
  if(PageSwapCache(mape1))
    {
      delete_from_swap_cache(mape1);
      pte = pte_mkdirty(pte);        /* Is this strictly necessary? */
    }
#endif

  set_pte(pte1,pte);
  set_pte(pte2,pte_mkold(pte));

#ifdef MERGEMOD_20
#else
  /* flush_tlb_page() takes the user address as an unsigned long. */
  flush_tlb_page(vma1,(unsigned long) i1.v_addr);
  flush_tlb_page(vma2,(unsigned long) i2.v_addr);
#endif

  /* A count of 1 means pte2 held the only reference to page2. */
  if(atomic_read(&mape2->count) == 1)
    {
      stat_merged++;
#ifdef MERGEMOD_20
#else
      if(mergemem_profile_hook1) mergemem_profile_hook1(mape1);
#endif
    }
  else
    {
      stat_moremap++;
      retval1 = MERGEMEM_MOVED;
    }

  sti(); /* possible end of critical section */
  /*  free or decrease count of page2 */
  free_page((unsigned long)i2.p_addr);

  /*  We would need to decrease the rss counter if we merged that page
      into the global ZERO_PAGE.
      ( Since rss is only increased in do_wp_page (copy on write) if the
      original page is the ZERO_PAGE. rss is increased in do_no_page)
  tsk2->mm->rss--;
  */

  return retval1;
}

/*
 * Look up the page mapped at ip->v_addr in task pid.
 *
 * On success fills ip->p_addr and ip->count (the page's mem_map reference
 * count), stores the VMA containing the address in *vmap and the PTE
 * pointer in *ptep, and returns MERGEMEM_SUCCESS.  Otherwise returns
 *   MERGEMEM_NOTASK1    - no task with that pid,
 *   MERGEMEM_PERMISSION - same_uid() refused access,
 *   MERGEMEM_NOPAGE1    - address not inside a mapping / no page present,
 *   MERGEMEM_NOANO      - not an anonymous mapping or page.
 * (mergemem() remaps the "1" codes when they concern its second page.)
 */
int get_pageinfo(pid_t pid,
		 MPAGEINFO * ip,
	      struct vm_area_struct * * const vmap,
	      pte_t ** ptep)
{
  struct task_struct * tsk;
  struct vm_area_struct * vma;
  unsigned long p_addr;
  unsigned long addr = (unsigned long) ip->v_addr;

  tsk = find_task_by_pid(pid);
  if(!tsk) return MERGEMEM_NOTASK1;
  if (!same_uid(tsk))
    return MERGEMEM_PERMISSION;

  /* find_vma() yields the first VMA that ends above addr, which need not
   * contain it: an address in a hole below that VMA is simply unmapped
   * and must not be judged by the next mapping's flags. */
  vma = find_vma(tsk->mm,addr);
  if(!vma || addr < vma->vm_start)
    return MERGEMEM_NOPAGE1;
  if ( noanonymous_mapping(vma) )
    return MERGEMEM_NOANO;
  p_addr = (unsigned long) get_phys_addr(tsk,addr,ptep);
  if( nopage(p_addr) )
    return MERGEMEM_NOPAGE1;

#ifdef MERGEMOD_20
  if( mem_map[MAP_NR(p_addr)].inode != NULL )
    return MERGEMEM_NOANO;
#endif

  *vmap = vma;
  ip->p_addr = (const void *) p_addr;
  ip->count = atomic_read(&mem_map[MAP_NR(p_addr)].count);
  return MERGEMEM_SUCCESS;
}

/* Calculate checksum for page */
/*
 * Checksum the page mapped at v_addr in task pid.
 *
 * On success stores m_hash_addrothalf() of the page (computed over its
 * first half) in *ret_chksum and the page's mem_map reference count in
 * *countp, then returns MERGEMEM_SUCCESS.  A get_pageinfo() failure is
 * returned unchanged and leaves both outputs untouched.
 */
static int genchksum(pid_t pid, const void * v_addr, unsigned long * ret_chksum,
		     int * countp)
{
  int retval;
  struct vm_area_struct * vma;
  pte_t * pte;
  MPAGEINFO i;

  i.v_addr = v_addr;
  retval = get_pageinfo(pid, &i, &vma, &pte);
  if (retval != MERGEMEM_SUCCESS)
    return retval;

  *ret_chksum = m_hash_addrothalf(i.p_addr,PAGE_SIZE);
  /* Report the reference count sampled by get_pageinfo() through the
   * count output, so callers do not read back stale data. */
  *countp = (int) i.count;
  return MERGEMEM_SUCCESS;
}

/*
 * Store the page address backing v_addr in task pid into *p_addrp.
 *
 * Returns the get_pageinfo() status; on any status other than
 * MERGEMEM_SUCCESS *p_addrp is set to NULL, because the lookup result is
 * not valid in that case.
 */
static int getphysaddr(pid_t pid, const void * v_addr, const void * * p_addrp)
{
  struct vm_area_struct * vma;
  pte_t * pte;
  MPAGEINFO i;
  int retval;

  i.v_addr = v_addr;
  retval = get_pageinfo(pid, &i, &vma, &pte);
  if (retval != MERGEMEM_SUCCESS)
    * p_addrp = NULL;
  else
    * p_addrp = i.p_addr;
  return retval;
}

/*
 * Module entry point: register the /dev/mergemem misc device, then the
 * /proc/mergemem entry.  If the proc entry cannot be registered, the
 * device is unregistered again.  Returns 0 on success, -EIO otherwise.
 */
int init_module()
{
  if (misc_register(&mergemem_dev) != 0) {
    printk(MODULE_NAME"unable to register device.\n");
    return -EIO;
  }

  if (proc_register(&proc_root, &mergemem_proc_dir) != 0) { /* _dynamic? */
    misc_deregister(&mergemem_dev);
    printk(MODULE_NAME"unable to register proc file.\n");
    return -EIO;
  }

  printk(MODULE_NAME"module initialised. Version: %d\n",MOD_VERSION);
  return 0;
}

/* Module exit point: remove the proc entry, then the misc device. */
void cleanup_module()
{
  proc_unregister(&proc_root, mergemem_proc_dir.low_ino);
  misc_deregister(&mergemem_dev);
  printk(MODULE_NAME"module released.\n");
}

/*
 * Copy the PAGE_SIZE bytes of the page mapped at v_addr in task pid to the
 * user buffer `to`.
 *
 * Returns MERGEMEM_SUCCESS, the get_pageinfo() status if the lookup did
 * not succeed (nothing is copied then), or -EFAULT if the user buffer
 * cannot be written.
 */
static int provide_page(pid_t pid,  const void * v_addr, void * to)
{
  struct vm_area_struct * vma;
  pte_t * pte;
  MPAGEINFO i;
  int retval;

  i.v_addr = v_addr;
  retval = get_pageinfo(pid, &i, &vma, &pte);
  if (retval != MERGEMEM_SUCCESS)
    return retval;
  if(copy_to_user(to,(void *)i.p_addr,PAGE_SIZE))
    return -EFAULT;
  return MERGEMEM_SUCCESS;
}


/* Predefined hashfunctions */

#define BITS_OF_UNSIGNEDLONG (sizeof(unsigned long)*8)

/*
 * Rotate-and-add hash over the first half of the buffer at addr.
 *
 * The first size/2 bytes are processed as whole unsigned longs: each step
 * rotates the running sum left by 3 bits and adds the next word.  Only
 * complete words lying inside that half are read, so a size whose half
 * is not a multiple of sizeof(unsigned long) never reads past it.  Page-
 * sized inputs always consist of whole words.  addr must be aligned
 * for unsigned long access.  Returns 0 when no whole word fits.
 */
static unsigned long
m_hash_addrothalf(const void* addr, const size_t size)
{
  const unsigned long * word = (const unsigned long *) addr;
  const size_t nwords = (size / 2) / sizeof(unsigned long);
  unsigned long chksum = 0;
  size_t k;

  for (k = 0; k < nwords; k++)
    chksum = (chksum << 3) + (chksum >> (BITS_OF_UNSIGNEDLONG-3)) + word[k];
  return chksum;
}

/* Degenerate hash: ignores its input and always yields 27.
 * NOTE(review): presumably a placeholder for testing; not called in the
 * visible code. */
static unsigned long
m_hash_const(const void* addr, size_t size)
{
  const unsigned long fixed_value = 27UL;

  (void) addr;
  (void) size;
  return fixed_value;
}












