/*
  mergemod_20.c
  Kernel module for 2.0.x Kernels
  
  This file is part of mergemem by Philipp Richter & Philipp Reisner

  mergemem is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; either version 2, or (at your option)
  any later version.
  
  mergemem is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with mergemem; see the file COPYING.  If not, write to
  the Free Software Foundation, 675 Mass Ave, Cambridge, MA 02139, USA.

*/

#ifdef HAVE_AUTOCONF
#include <linux/autoconf.h> /* for CONFIG_MODVERSIONS */
#endif   
#include <linux/config.h>
#include <linux/module.h>
#ifdef CONFIG_MODVERSIONS
#include <linux/modversions.h>
#endif    
#include <linux/kernel.h>
#include <linux/fs.h>
#include <linux/proc_fs.h>
#include <linux/sched.h>
#include <linux/swap.h>
#include <linux/miscdevice.h>

#include <asm/page.h>
#include <asm/segment.h>
#include <asm/io.h>
#include <asm/pgtable.h>

#include "mergemod.h"

/* Width of a long in bits (assumes 8-bit bytes); trivial_chksum() uses it
 * to rotate its running checksum. */
#define BITS_OF_LONG (sizeof(long)*8)

/* Inode number that proc_register_dynamic() assigned to /proc/mergemem in
 * init_module(); cleanup_module() hands it to proc_unregister(). */
static int mergemem_inode;

/* statistics for proc filesys: page counters printed by /proc/mergemem,
 * incremented in mergemem() and zeroed by the MERGEMEM_RESET_STAT ioctl */
static unsigned long stat_merged = 0L;    /* merged, old page had count == 1 */
static unsigned long stat_alreadysh = 0L; /* both addresses already mapped the same page */
static unsigned long stat_notequ = 0L;    /* page contents differed (memcmp) */
static unsigned long stat_moremap = 0L;   /* merged, old page had count != 1 */

/* File-scope copies of the ioctl argument structures.
 * NOTE(review): a single shared instance per command can be clobbered if
 * two processes' ioctls interleave while a user-space copy may sleep on a
 * page fault; per-call stack copies in the handler avoid that. */
static struct mergemem_chksum mm_chksum;
static struct mergemem_mmem mm_mmem;
static struct mergemem_get_phys_addr mm_get_phys_addr;

/* proc file ops (/proc/mergemem: read-only statistics) */
static int mergemem_proc_open(struct inode*, struct file*);
static void mergemem_proc_close(struct inode*,struct file*);
static READ_TYPE mergemem_proc_read(struct inode*, struct file*, char*, READ_COUNT_TYPE);

/* device ops (/dev/mergemem: control interface driven by ioctl) */
static int mergemem_dev_open(struct inode*, struct file*);
static void mergemem_dev_close(struct inode*,struct file*);
static int mergemem_dev_ioctl(struct inode*, struct file*, unsigned int cmd, unsigned long arg);

/* the important routines: page merging, page checksumming, and
 * user-address to page lookup */
static int mergemem(int, unsigned long, int, unsigned long );
static int genchksum(int, unsigned long, unsigned long*, int* );
static int getphysaddr( pid_t, unsigned long, unsigned long * );

/* File operations of the /dev/mergemem misc device.  Only ioctl, open and
 * release are provided.  This is a positional initializer: the entries must
 * follow the field order of struct file_operations in the 2.0.x kernel the
 * module is built against. */
static struct file_operations mergemem_dev_fops =
{
  NULL,                   /* lseek - default */
  NULL,                   /* read */
  NULL,                   /* write  */
  NULL,                   /* readdir */
  NULL,                   /* select - default */
  mergemem_dev_ioctl,     /* ioctl */
  NULL,                   /* mmap */
  mergemem_dev_open,      /* open code */
  mergemem_dev_close,     /* release code */
  NULL                    /* can't fsync */
};

/* File operations of /proc/mergemem.  Only read, open and release are
 * provided.  This is a positional initializer: the entries must follow the
 * field order of struct file_operations in the target 2.0.x kernel. */
static struct file_operations mergemem_proc_fops = 
{
  NULL,                   /* lseek - default */
  mergemem_proc_read,     /* read */
  NULL,                   /* write  */
  NULL,                   /* readdir */
  NULL,                   /* select - default */
  NULL,                   /* ioctl - default */
  NULL,                   /* mmap */
  mergemem_proc_open,     /* open code */
  mergemem_proc_close,    /* release code */
  NULL                    /* can't fsync */
};

/* Inode operations of /proc/mergemem.  Only the default file operations
 * pointer is set; every inode-level operation is left NULL.  Positional
 * initializer: order must match struct inode_operations of the target
 * 2.0.x kernel. */
static struct inode_operations mergemem_proc_iops =
{
  &mergemem_proc_fops,    /* default file operations for this inode */
  NULL,                   /* create */
  NULL,                   /* lookup */
  NULL,                   /* link */
  NULL,                   /* unlink */
  NULL,                   /* symlink */
  NULL,                   /* mkdir */
  NULL,                   /* rmdir */
  NULL,                   /* mknod */
  NULL,                   /* rename */
  NULL,                   /* readlink */
  NULL,                   /* follow_link */
  NULL,                   /* readpage */
  NULL,                   /* writepage */
  NULL,                   /* bmap */
  NULL,                   /* truncate */
  NULL                    /* permission */
};
  
/* Directory entry for /proc/mergemem, registered under proc_root by
 * init_module() via proc_register_dynamic().  low_ino starts as 0 and is
 * read back after registration, so the inode number is assigned
 * dynamically.  The remaining pointer fields are left NULL.
 * NOTE(review): the field labels below follow the 2.0.x struct
 * proc_dir_entry layout (positional initializer) -- confirm against the
 * target kernel.  The object is not static; make it static unless another
 * file refers to it. */
struct proc_dir_entry mergemem_proc_dir = 
{
  0, 8, "mergemem",             /* low_ino (dynamic), namelen == strlen("mergemem"), name */
  S_IFREG | S_IRUGO , 1, 0, 0,  /* mode: regular file readable by all; nlink; uid; gid */
  0, &mergemem_proc_iops,       /* size; inode operations */
  NULL, NULL,
  NULL,
  NULL, NULL
};

/* The /dev/mergemem misc character device (minor MERGEMEM_MINOR), using
 * mergemem_dev_fops; registered by init_module(). */
static struct miscdevice mergemem_dev =
{
  MERGEMEM_MINOR, "mergemem", &mergemem_dev_fops
};

/*
 * Module entry point.
 *
 * Registers the /dev/mergemem misc device and then the /proc/mergemem entry,
 * and saves the inode number proc_register_dynamic() assigned so that
 * cleanup_module() can unregister it.
 *
 * Returns 0 on success.  If a registration fails, anything already
 * registered is undone and the negative error code reported by the failing
 * call is returned (so insmod shows e.g. -EBUSY for a taken minor rather
 * than a generic error).  -EIO is used only if that call reported failure
 * with a non-negative value.
 */
int init_module(void)
{
  int err;

  /* register device */
  err = misc_register(&mergemem_dev);
  if(err)
  {
    printk(MODULE_NAME"unable to register device.\n");
    return (err < 0) ? err : -EIO;
  }

  /* register proc */
  err = proc_register_dynamic(&proc_root, &mergemem_proc_dir);
  if(err)
  {
    misc_deregister(&mergemem_dev);
    printk(MODULE_NAME"unable to register proc file.\n");
    return (err < 0) ? err : -EIO;
  }

  /* low_ino is filled in by proc_register_dynamic() */
  mergemem_inode = mergemem_proc_dir.low_ino;

  printk(MODULE_NAME"module initialised.\n");
  return 0;
}

/*
 * Module exit point: removes the /proc/mergemem entry (identified by the
 * inode number saved in init_module()), then the /dev/mergemem misc device,
 * and logs the unload.
 */
void cleanup_module(void)
{
  /* proc entry first, then the device -- reverse of registration */
  proc_unregister(&proc_root, mergemem_inode);
  misc_deregister(&mergemem_dev);

  printk(MODULE_NAME"module released.\n");
}

/* proc file operations */
/* open() handler for /proc/mergemem: takes a module reference
 * (MOD_INC_USE_COUNT) that mergemem_proc_close() drops again; always
 * succeeds. */
static int mergemem_proc_open(struct inode *inode, struct file *file)
{
  MOD_INC_USE_COUNT;

#ifdef DEBUG
  printk(MODULE_NAME"/proc/mergemem opened.\n");
#endif

  return 0;
}

/* release() handler for /proc/mergemem: drops the module reference taken by
 * mergemem_proc_open().  MOD_DEC_USE_COUNT is deliberately the last
 * statement, so the optional debug message is printed while the reference
 * is still held. */
static void mergemem_proc_close(struct inode *inode, struct file *file)
{
#ifdef DEBUG
  printk(MODULE_NAME"/proc/mergemem closed.\n");
#endif
  MOD_DEC_USE_COUNT;
}

/*
 * read() handler for /proc/mergemem.
 *
 * Formats the four merge statistics (as pages and as kilobytes) into a local
 * text buffer.  It then copies at most `count` bytes, starting at
 * file->f_pos, to the user buffer `buf`, and advances f_pos by the amount
 * copied.
 *
 * Returns the number of bytes copied, or 0 once f_pos is at or past the end
 * of the text.  `buf` is not verified here.  NOTE(review): this relies on
 * sys_read() having done verify_area(VERIFY_WRITE) -- confirm for the
 * targeted 2.0.x release.
 */
static READ_TYPE mergemem_proc_read(struct inode *inode, struct file *file, char *buf, READ_COUNT_TYPE count)
{
  /* Sized for 64-bit longs.  Each line is a 16-char label plus two numbers
   * of up to 20 digits plus 12 chars of fixed text (" pages, " and " kb\n"),
   * i.e. 68 chars.  Four lines plus the NUL need 273 bytes.  The previous
   * 200-byte buffer could be overrun by sprintf on 64-bit ports. */
  char buffer[320];
  /* Kept unsigned and full-width: a negative f_pos becomes a huge value and
   * reads as EOF, and a large 64-bit f_pos is not truncated back into the
   * buffer range. */
  unsigned long fpos = file->f_pos;
  unsigned long len;

#ifdef DEBUG
  printk(MODULE_NAME"reading...\n");
#endif

  sprintf(buffer,
          "merged        : %lu pages, %lu kb\n"
          "not equal     : %lu pages, %lu kb\n"
          "already shared: %lu pages, %lu kb\n"
          "mappings > 1  : %lu pages, %lu kb\n",
          stat_merged, stat_merged * PAGE_SIZE / 1024,
          stat_notequ, stat_notequ * PAGE_SIZE / 1024,
          stat_alreadysh, stat_alreadysh * PAGE_SIZE / 1024,
          stat_moremap, stat_moremap * PAGE_SIZE / 1024 );

  len = strlen(buffer);
  if(fpos >= len)
    return 0;

  /* clamp the request to what is left of the text */
  if(len - fpos < count)
    count = len - fpos;

  memcpy_tofs(buf, buffer + fpos, count);
  file->f_pos += count;

  return count;
}

/* device ops */
/* open() handler for /dev/mergemem: takes a module reference
 * (MOD_INC_USE_COUNT) that mergemem_dev_close() drops again; always
 * succeeds. */
static int mergemem_dev_open(struct inode *inode, struct file *file)
{
  MOD_INC_USE_COUNT;

#ifdef DEBUG
  printk(MODULE_NAME"/dev/mergemem opened.\n");
#endif

  return 0;
}

/* release() handler for /dev/mergemem: drops the module reference taken by
 * mergemem_dev_open().  MOD_DEC_USE_COUNT is deliberately the last
 * statement, so the optional debug message is printed while the reference
 * is still held. */
static void mergemem_dev_close(struct inode *inode, struct file *file)
{
#ifdef DEBUG
  printk(MODULE_NAME"/dev/mergemem closed.\n");
#endif
  MOD_DEC_USE_COUNT;
}

static int mergemem_dev_ioctl(struct inode *inode, struct file *file, unsigned int cmd, unsigned long arg)
{
  int retval;

  switch(cmd)
  {
    /* version checking */
    case MERGEMEM_CHECK_VER:
      retval = verify_area(VERIFY_WRITE, (void *)arg, sizeof(int));
      if(retval)
	return(retval);
      put_user(MOD_VERSION, (int *)arg);
      return 0;
      
      /* generate a checksum for a page */
    case MERGEMEM_GEN_CHECKSUM:
      /* does VERIFY_WRITE implie VERIFY_READ ??? */
      retval = verify_area(VERIFY_WRITE, (void *)arg, sizeof(struct mergemem_chksum));
      if(retval)
	return(retval);
      memcpy_fromfs(&mm_chksum, (struct mergemem_chksum*)arg, sizeof(struct mergemem_chksum));
      retval = genchksum( mm_chksum.pid, mm_chksum.addr, &(mm_chksum.chksum), &(mm_chksum.nrefs));
      memcpy_tofs((struct mergemem_chksum*)arg, &mm_chksum, sizeof(struct mergemem_chksum));
      return retval;
      
    case MERGEMEM_MERGE_MEM:
      retval = verify_area(VERIFY_READ, (void *)arg, sizeof(struct mergemem_mmem));
      if(retval)
	return(retval);
      memcpy_fromfs(&mm_mmem, (struct mergemem_mmem*)arg, sizeof(struct mergemem_mmem));
      retval = mergemem( mm_mmem.pid1, mm_mmem.addr1, mm_mmem.pid2, mm_mmem.addr2 );
      return retval;

    case MERGEMEM_GET_PHYS_ADDR:
      retval = verify_area(VERIFY_WRITE, (void *)arg, sizeof(struct mergemem_get_phys_addr));
      if(retval)
	return(retval);
      memcpy_fromfs(&mm_get_phys_addr, (struct mergemem_get_phys_addr*)arg, sizeof(struct mergemem_get_phys_addr));
      retval = getphysaddr( mm_get_phys_addr.pid, mm_get_phys_addr.addr, &(mm_get_phys_addr.phys_addr));
      memcpy_tofs((struct mergemem_get_phys_addr*)arg, &mm_get_phys_addr, sizeof(struct mergemem_get_phys_addr));
      return retval;

    case MERGEMEM_RESET_STAT:
      stat_merged = 0L;
      stat_alreadysh = 0L;
      stat_notequ = 0L;
      stat_moremap = 0L;
      return 0;
      
    default:
      return -EINVAL;
  }
  return 0;
}

/****************************************************************************/
/* From here on, we have the implementation of the page fusion function     */


static struct task_struct * get_task(int pid)
{
  struct task_struct * tsk, * st;

  st = tsk = current_set[0];
  
  do
  {
    if(tsk->pid == pid) return tsk;
    tsk=tsk->next_task;
  }
  while(tsk != st);
  
  return 0;
}

/*
 * Resolve user address `addr` of task `tsk` by walking its page tables
 * (pgd -> pmd -> pte).
 *
 * On success, returns the page address taken from the pte (pte_page())
 * and stores a pointer to that page-table entry in *pe.
 *
 * Returns 0 and leaves *pe untouched when:
 *   - the task has no mm or page directory, or
 *   - `addr` is not a user address (>= TASK_SIZE), or
 *   - a pgd or pmd entry is empty or bad, or
 *   - the pte is not present.
 *
 * A bad pgd/pmd entry is reported with printk and cleared.
 */
static char * get_phys_addr(struct task_struct *tsk, unsigned long addr, pte_t **pe)
{
  pgd_t *page_dir;
  pmd_t *page_middle;
  pte_t *ptep;
  pte_t pte;

  /* Addresses at or above TASK_SIZE are not user mappings.  Refusing them
   * keeps ioctl callers (and mergemem(), which rewrites and frees whatever
   * this returns) away from kernel mappings.  Unlike the bad-entry paths
   * below, these checks also avoid clearing kernel page-table entries. */
  if (!tsk->mm || !tsk->mm->pgd || addr >= TASK_SIZE)
    return (char *)0;

  page_dir = pgd_offset(tsk->mm,addr);
  if (pgd_none(*page_dir)) return (char *)0;
  if (pgd_bad(*page_dir))
    {
      printk("Bad page dir entry %08lx\n", pgd_val(*page_dir));
      pgd_clear(page_dir);
      return (char *)0;
    }

  page_middle = pmd_offset(page_dir,addr);
  if (pmd_none(*page_middle)) return (char *)0;
  if (pmd_bad(*page_middle))
    {
      printk("Bad page middle entry %08lx\n", pmd_val(*page_middle));
      pmd_clear(page_middle);
      return (char *)0;
    }

  /* look the pte up once and reuse the pointer for *pe */
  ptep = pte_offset(page_middle,addr);
  pte = *ptep;
  if (!pte_present(pte)) return (char *)0;

  *pe = ptep;

  return (char *) pte_page(pte);
}

static int mergemem(int pid1, unsigned long addr1, int pid2, unsigned long addr2)
{
  struct task_struct *tsk1, *tsk2;
  char *page1, *page2;
  pte_t *pte1, *pte2, pte;
  mem_map_t *mape1, *mape2;
  int retval=MERGEMEM_SUCCESS;
  
  tsk1 = get_task(pid1);
  if(!tsk1) return MERGEMEM_NOTASK1;
  tsk2 = get_task(pid2);
  if(!tsk2) return MERGEMEM_NOTASK2;
  
  page1 = get_phys_addr(tsk1,addr1,&pte1);
  if(!page1 || (unsigned long)page1 >= high_memory ) return MERGEMEM_NOPAGE1;
  page2 = get_phys_addr(tsk2,addr2,&pte2);
  if(!page2 || (unsigned long)page2 >= high_memory ) return MERGEMEM_NOPAGE2;

  mape1 = &mem_map[MAP_NR(page1)];
  mape2 = &mem_map[MAP_NR(page2)];

  if(page1 == page2) 
  {
    stat_alreadysh++;
    return MERGEMEM_ALREADYSH;
  }

  if(mape1->inode != NULL || 
     mape2->inode != NULL )
    return MERGEMEM_NOANO;

  if(PageLocked(mape1) || PageLocked(mape2)) 
    return MERGEMEM_PLOCKED;
  
  /* cli(); start of critical section */
  if(memcmp(page1,page2,PAGE_SIZE))
    {
      /* sti();  possible end of critical section */
      stat_notequ++;
      return MERGEMEM_NOTEQUAL;
    }
  
  /* Do the actual work... */

  /*  new pte's are readonly and point to page1 */
  pte = pte_wrprotect(*pte1);

  /*  increase the count in the according mem_map structure */
  mape1->count++;

  /* page1 must go off the swap_cache, since it is now mapped more
   * than one time
   */
  if(delete_from_swap_cache(MAP_NR(page1)))
    pte = pte_mkdirty(pte);

  set_pte(pte1,pte);
  set_pte(pte2,pte_mkold(pte));

  if(mape2->count == 1) stat_merged++;
  else
    {
      stat_moremap++;
      retval=MERGEMEM_MOVED;
    }

  /* sti(); possible end of critical section */  
  /*  free or decrease count of page2 */
  free_page((unsigned long)page2);

  /*  We would need to decreas the rss counter, if we would merge
      that page to the global ZERO_PAGE.
      ( Since rss is only increased in do_wp_page (copy on write) if the 
      original page is the ZERO_PAGE. rss is increased in do_no_page)
  tsk2->mm->rss--; 
  */

  return retval;
}

/*
 * Cheap checksum over the first half of the page at `page`.
 *
 * For every long word w in that half, the running value c becomes
 * rotate_left(c, 3) + w, where the rotation is over BITS_OF_LONG bits.
 * Only half the page is read, for speed.  A matching checksum is therefore
 * just a hint; mergemem() still compares the full page contents before
 * merging.
 */
static unsigned long trivial_chksum(unsigned long* page)
{
  const unsigned long nwords = (PAGE_SIZE / sizeof(unsigned long)) / 2;
  unsigned long sum = 0;
  unsigned long i;

  for(i = 0; i < nwords; i++)
    sum = ((sum << 3) + (sum >> (BITS_OF_LONG - 3))) + page[i];

  return sum;
}

/* Calculate checksum for page */
static int genchksum(int pid, unsigned long addr, unsigned long* ret_chksum,
		     int* ret_mappings)
{
  struct task_struct * tsk;
  unsigned long * page;
  pte_t * pe;

  tsk = get_task(pid);
  if(!tsk) return MERGEMEM_NOTASK1;

  page = (unsigned long *) get_phys_addr(tsk,addr,&pe);
  if(!page || (unsigned long)page >= high_memory ) return MERGEMEM_NOPAGE1;

  if( mem_map[MAP_NR(page)].inode != NULL )
    return MERGEMEM_NOANO;

  *ret_chksum = trivial_chksum(page);
  *ret_mappings = mem_map[MAP_NR(page)].count;

  return MERGEMEM_SUCCESS;
}

/*
 * Report the page behind user address `addr` of process `pid`.
 *
 * Stores the value returned by get_phys_addr() in *phys_addr when it is
 * below high_memory.  Otherwise it stores 0, which also covers an unmapped
 * or non-present address, since get_phys_addr() then returns 0.
 *
 * Returns MERGEMEM_NOTASK1 (leaving *phys_addr untouched) if no task has
 * that pid, and 0 otherwise.
 */
int getphysaddr( pid_t pid, unsigned long addr, unsigned long *phys_addr )
{
  struct task_struct *tsk = get_task(pid);
  pte_t *unused_pte;
  unsigned long page;

  if(tsk == NULL)
    return MERGEMEM_NOTASK1;

  page = (unsigned long) get_phys_addr(tsk, addr, &unused_pte);
  *phys_addr = (page < high_memory) ? page : 0;

  return 0;
}





