/* UNIX shared memory example to find the sum of the first 1000 numbers
   in parallel using two processes, one to add even numbered elements and
   one to add odd numbered elements */

#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#include <sys/sem.h>
#include <sys/wait.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#define array_size 1000                         /* number of elements in shared memory array */
/* shmat() is declared by <sys/shm.h> (returning void *); a local
   "extern char *shmat();" re-declaration conflicts with that prototype
   and must not be repeated here. */
void spin_lock_init();                          /* as given earlier */
void spin_lock();
void spin_unlock();

/*
 * Sum the integers 1..array_size with two cooperating processes.
 *
 * A System V shared memory segment holds one accumulator followed by the
 * array A[0..array_size-1].  After fork() the child adds the elements at
 * even indices and the parent those at odd indices; each process then adds
 * its partial sum to the shared accumulator while holding the semaphore
 * lock.  The parent waits for the child, prints the total and removes the
 * semaphore set and the shared memory segment.
 *
 * Returns 0 on success, 1 if any system call failed.
 */
int main(void)
{
  key_t shm_key = 0x567;                  /* shared memory key used for shmget() */
  int shmid = -1;                         /* shared memory id; -1 until created */
  int semid;                              /* semaphore id, filled in by spin_lock_init() */
  pid_t pid;                              /* fork() result: 0 in the child */
  void *shm;                              /* segment address returned by shmat() */
  int *sum;                               /* shared accumulator (first int of segment) */
  int *A;                                 /* shared array (follows the accumulator) */
  int partial_sum = 0;                    /* partial sum of each process */
  int status = 0;                         /* exit status of the parent */
  int first;                              /* first index summed by this process */
  int i;

  spin_lock_init(&semid);                 /* initialize semaphore set */

  /* Room for the accumulator plus array_size elements.  The size must be
     (array_size + 1) ints, not array_size ints plus one byte. */
  shmid = shmget(shm_key, (array_size + 1) * sizeof(int), IPC_CREAT | 0600);
  if (shmid == -1) {
    perror("shmget");
    status = 1;
    goto cleanup;
  }

  shm = shmat(shmid, NULL, 0);            /* map segment to process data space */
  if (shm == (void *) -1) {
    perror("shmat");
    status = 1;
    goto cleanup;
  }

  sum = (int *)shm;                       /* accumulating sum */
  A = sum + 1;                            /* array of numbers, A */
  *sum = 0;
  for (i = 0; i < array_size; i++)        /* load array with numbers */
    A[i] = i + 1;

  pid = fork();                           /* create child process */
  if (pid == -1) {
    perror("fork");
    status = 1;
    goto cleanup;
  }

  /* Child takes indices 0, 2, 4, ...; parent takes 1, 3, 5, ... */
  first = (pid == 0) ? 0 : 1;
  for (i = first; i < array_size; i += 2)
    partial_sum += A[i];

  spin_lock(&semid);                      /* for each process, add partial sum */
  *sum += partial_sum;
  spin_unlock(&semid);

  printf("\nprocess pid = %d, partial sum = %d\n", (int)pid, partial_sum);

  if (pid == 0)
    exit(0);                              /* child is done; parent owns the cleanup */

  if (wait(NULL) == -1) {                 /* total is only complete once the child exits */
    perror("wait");
    status = 1;
  }
  printf("\nthe sum of 1 to %i is %d\n", array_size, *sum);

cleanup:
  /* Attempt both removals even if one fails, so that neither IPC object is
     left behind.  IPC_RMID ignores the optional fourth semctl() argument.
     NOTE(review): assumes spin_lock_init() left a valid semaphore id in
     semid -- confirm against its definition. */
  if (semctl(semid, 0, IPC_RMID) == -1) { /* remove semaphore */
    perror("semctl");
    status = 1;
  }
  if (shmid != -1 && shmctl(shmid, IPC_RMID, NULL) == -1) { /* remove shared memory */
    perror("shmctl");
    status = 1;
  }
  return status;
}

