//                               GRAVTUT.CPP

#include <stdio.h>
#include <graphics.h>
#include <stdlib.h>
#include <conio.h>
#include <math.h>
#include <ctype.h>

// Simulation-wide constants.  Expression-valued macros are parenthesised so
// that they expand as a single operand wherever they are used.
#define G 0.0006673          //universal gravitational constant (in simulation units)
#define AUtoSDU 14.9598      //convert AU to simulation distance unit
#define YRtoSTU 3147.0336    //convert years to simulation time unit
#define maxp 3               //maximum number of particles
#define sccx 320             //x-coordinate of center of screen
#define sccy 240             //y-coordinate of center of screen
#define zinc 1.33352         //zoom step factor
#define sinc 1.5             //shift increment
#define rinc (5 * 0.0174532) //rotation increment angle (5 degrees, in radians)

// State and operations of the gravity tutorial: the particle arrays, the
// view parameters used to project particles onto the screen, and the run
// loop with its keyboard handling.  Everything is public; main() reads and
// writes several of the data members directly.
class simulation
{
  public:
  // The int argument is not used by the constructor.
  simulation(int snump);
  ~simulation();
  // Advances every particle by one timestep.
  void core(void);
  void input();                //finds user input and routes appropriately
  // Converts a simulation-space point (x, y, z) into screen pixel
  // coordinates, written to *graphx and *graphy.
  void gtrans(int *graphx, int *graphy, float x, float y, float z);
  void drawparticles(int gorder, int style);  //draw or erase all particles
  // Draws an arrow for the velocity of particle p in color c.
  void drawvelocityvector(int p, int c);
  void pause(void);                //holds simulation during user pause
  void imagezoom(int zdir);        //0 = zoom in;  1 = zoom out
  void changedelay(int ddir);      //nonzero = increase delay;  0 = decrease delay
  void displayinfo(void);          //displays simulation statistics
  void clearinfo(void);            //clears displayinfo() display
  void imagerotate(int rdir);      //0/1 = rota +/-;  2/3 = rotb +/-;  4/5 = rotc +/-

  // Stores a particle (position, velocity, mass m, color c) at index nump.
  void addobject(double x, double y, double z, double vx, double vy, double vz, double m, int c);
  // Adds (dx, dy, dz) to the position of particle obj.
  void moveobject(int obj, double dx, double dy, double dz);
  // Adds (dvx, dvy, dvz) to the velocity of particle obj.
  void accobject(int obj, double dvx, double dvy, double dvz);
  // Resets the particle count to zero.
  void flush();

  // Calls initgraph() with driver auto-detection; mode is passed through as
  // initgraph's third argument.
  void opengraphicsmode(char *mode);
  // Runs the simulate/draw/input loop until exitflag is set.
  void runsim();

  long iter;           //iteration number
  double t;            //time (STU)
  double timestep;     //time increment (STU=10^4s)
  int nump;            //number of particles
  char com;            //user input variable (last key returned by getch())
  float centx;         //x-coordinate of effective center of screen
  float centy;         //y-coordinate of effective center of screen
  float centz;         //z-coordinate of effective center of screen
  float offx;          //offset added to the screen x-coordinate in gtrans()
  float offy;          //offset added to the screen y-coordinate in gtrans()
  float rota;          //x-y rotation
  float rotb;          //x-z rotation (adjusted by imagerotate(); not applied by gtrans())
  float rotc;          //y-z rotation
  float zoom;          //zoom factor
  int ptrack;          //tracked particle ID (negative = no tracking)
  int exitflag;        //flags exit command
  int pauseflag;       //flags pause
  int displayflag;     //flags display info
  long delay;          //steps in timing loop delay
  int tracers;         //flags tracers (imagerotate() sets 2 temporarily; pause() restores 1)
  int particlestyle;   //appearance of particles 0=point 1=cross 2=circle
  int drawfreq;        //redraw particles every drawfreq timesteps

  double px[maxp];         //x-coordinate
  double py[maxp];         //y-coordinate
  double pz[maxp];         //z-coordinate
  double pvx[maxp];        //x-axis velocity
  double pvy[maxp];        //y-axis velocity
  double pvz[maxp];        //z-axis velocity
  double pm[maxp];         //mass
  int pc[maxp];           //pixel color
  int pc2[maxp];          //tracer color
  int pgx[maxp];        //graphical x-coordinate (last drawn)
  int pgy[maxp];        //graphical y-coordinate (last drawn)
};

// Builds an empty simulation (no particles) with the default view and run
// settings.  The int argument is accepted but not used.
//
// Every data member is given a defined value here, including the clock
// (iter, t), the last key (com) and the per-particle arrays, so that no
// member can be read before it has been written.
simulation::simulation(int dummy)
{
   int p;

   (void)dummy;                  //intentionally unused

   //clock and input state
   iter = 0;
   t = 0.0;
   com = 0;

   //run settings
   nump = 0;
   timestep = 0.125;
   pauseflag = exitflag = displayflag = 0;
   delay = 0;
   tracers = 1;
   particlestyle = 2;
   drawfreq = 2;

   //view settings
   centx = centy = centz = 0;
   offx = offy = -0.01;
   rota = rotb = rotc = 0;
   zoom = 5;
   ptrack = 0;

   //particle slots: zeroed, with the same "never drawn" screen position
   //(-1, -1) that addobject() assigns
   for (p = 0; p < maxp; p++)
   {
      px[p] = py[p] = pz[p] = 0.0;
      pvx[p] = pvy[p] = pvz[p] = 0.0;
      pm[p] = 0.0;
      pc[p] = pc2[p] = 0;
      pgx[p] = pgy[p] = -1;
   }
}

// Empty destructor: this function performs no cleanup.
simulation::~simulation() {}

void simulation::core(void)
{
   //Use Euler's method to update the particle positions after a duration
   // equal to one timestep.

   int p1, p2;                //particle IDs
   double dx, dy, dz;         //distance between particles along axes
   double D;                  //resultant distance between particles
   double A;                  //resultant acceleration
   double pax, pay, paz;      //aceleration along each axis

   for (p1 = 0; p1 < nump; p1 ++)
   {
      pax = pay = paz = 0.0;
      for (p2 = 0; p2 < nump; p2 ++)            //determine acceleration
      {
	 if (p1 != p2)
	 {
	    dx = px[p2] - px[p1];
	    dy = py[p2] - py[p1];
	    dz = pz[p2] - pz[p1];
	    D = sqrt(dx*dx + dy*dy + dz*dz);
	    A = G * pm[p2] / (D*D);
	    pax += dx * A / D;
	    pay += dy * A / D;
	    paz += dz * A / D;
	 }
      }
      pvx[p1] += pax * timestep;            //accelerate
      pvy[p1] += pay * timestep;
      pvz[p1] += paz * timestep;
   }

   for (p1 = 0; p1 < nump; p1++)
   {
      px[p1] += pvx[p1] * timestep;          //move
      py[p1] += pvy[p1] * timestep;
      pz[p1] += pvz[p1] * timestep;
   }

   //bad collision detection
   for (p1 = 1; p1 < nump; p1++)
   {
     if ((px[p1]*px[p1] + py[p1]*py[p1] + pz[p1]*pz[p1]) < 0.075)
      {px[p1] = 1e+6; py[p1] = 1e+6; pvx[p1] = 0.01; pvy[p1] = 0.01;}
   }

   return;
}

/*******************************************************************/
// Waits for one key with getch(), stores it in com and dispatches it:
//   Esc (27)    sets exitflag (any other key clears it)
//   '-' / '_'   changedelay(0)
//   '+' / '='   changedelay(1)
//   space       pause(), unless a pause is already in progress
void simulation::input()
{
   com = getch();
   exitflag = (com == 27) ? 1 : 0;

   switch (com)
   {
      case '-':
      case '_':
         changedelay(0);
         break;

      case '+':
      case '=':
         changedelay(1);
         break;

      case ' ':
         if (!pauseflag) pause();
         break;

      default:
         break;
   }

   //Key bindings that are currently disabled:
   //
   //if ((com == 'i') || (com == 'I')) imagezoom(0);
   //if ((com == 'o') || (com == 'O')) imagezoom(1);
   //if ((com == ',') || (com == '<')) imagerotate(0);
   //if ((com == '.') || (com == '>')) imagerotate(1);
   //if ((com == 'l') || (com == 'L')) imagerotate(2);
   //if ((com == ';') || (com == ':')) imagerotate(3);
   //if  (com == 'p')                  imagerotate(4);
   //if ((com == '[') || (com == '{')) imagerotate(5);
   //if ((com == 'c') || (com == 'C')) {clearviewport(); drawparticles(1, particlestyle);}
   //
   //if ((com == '?') || (com == '/')) {pause(); displayinfo();}
   //if ((com == 't') || (com == 'T')) tracers = !tracers;
   //if (com == 'r') {timestep = -timestep;}
   //if (com >= '0' && com <= '0'+nump && com <= '9')
   //{  if (com != 1)
   //   clearviewport();
   //   ptrack = com - '0';
   //   drawparticles(1, particlestyle);
   //}
   //if (com == '`') {clearviewport(); drawparticles(1,particlestyle); ptrack = -1;}
}
/*******************************************************************/
// Projects the simulation-space point (x, y, z) onto the screen.
//
// The point is taken relative to the view center (centx, centy, centz),
// rotated by rota in the x-y plane, then by rotc in the y-z plane, scaled by
// zoom and shifted to the screen center plus (offx, offy).  The result is
// truncated to int and written to *graphx / *graphy.  rotb is not applied.
void simulation::gtrans(int *graphx, int *graphy, float x, float y, float z)
{
   //position relative to the view center
   const float relx = x - centx;
   const float rely = y - centy;
   const float relz = z - centz;

   //rotation in the x-y plane
   const float spunx = relx * cos(rota) + rely * sin(rota);
   const float spuny = rely * cos(rota) - relx * sin(rota);

   //rotation in the y-z plane; only the y component reaches the screen
   const float tilty = spuny * cos(rotc) + relz * sin(rotc);

   *graphx = (int)(zoom * (spunx) + sccx + offx);
   *graphy = (int)(zoom * (tilty) + sccy + offy);
}

/*******************************************************************/
// Draws, erases or leaves a tracer for every particle.
//
// gorder selects the operation:
//   >= 1  recompute the screen position, remember it and draw the particle in
//         its own color pc[] (colors 15 and 7 become 0 when the background
//         color is 15)
//      0  draw at the remembered position in the tracer color pc2[], or in
//         color 0 when tracers == 0
//     -1  draw at the remembered position in color 0
//     -2  draw at the remembered position in color 15
// For gorder < 2 a particle whose screen position has not changed is skipped.
//
// style is the shape used for a particle whose mass is exactly 1:
//   0 = single pixel, 1 = five-pixel cross, >= 2 = filled circle of radius
//   2 + (style - 2).
// Particles with mass > 1 are always drawn with style 3, those with mass < 1
// with style 2.  The override is evaluated per particle; it no longer carries
// over from one particle to the next.
void simulation::drawparticles(int gorder, int style)
{
   int p1;

   for (p1 = 0; p1 < nump; p1++)
   {
      //shape for this particle only
      int pstyle = style;
      if (pm[p1] > 1) pstyle = 3;
      if (pm[p1] < 1) pstyle = 2;

      int newgx, newgy;
      gtrans(&newgx, &newgy, px[p1], py[p1], pz[p1]);

      //nothing to do if the particle has not moved on screen
      if (pgx[p1] == newgx && pgy[p1] == newgy && gorder < 2) continue;

      //defined default, so an unexpected gorder cannot use an indeterminate color
      int color = pc[p1];

      if (gorder >= 1)
      {
         //track: re-center the view on the tracked particle, if it exists
         if (ptrack >= 0 && ptrack < nump)
         {
            centx = px[ptrack];
            centy = py[ptrack];
            centz = pz[ptrack];
         }

         pgx[p1] = newgx;
         pgy[p1] = newgy;

         //keep white/light-gray particles visible on a white background
         if (getbkcolor() == 15)
         {
            if (color == 15 || color == 7) color = 0;
         }
      }
      if (gorder == 0) color = pc2[p1];
      if ((gorder == -1) || ((gorder == 0) && (tracers == 0))) color = 0;
      if (gorder == -2) color = 15;

      //draw
      if (pstyle == 0) putpixel(pgx[p1], pgy[p1], color);
      if (pstyle == 1)
      {
         putpixel(pgx[p1], pgy[p1], color);
         putpixel(pgx[p1] + 1, pgy[p1], color);
         putpixel(pgx[p1] - 1, pgy[p1], color);
         putpixel(pgx[p1], pgy[p1] + 1, color);
         putpixel(pgx[p1], pgy[p1] - 1, color);
      }
      if (pstyle >= 2)
      {
         setcolor(color);
         setfillstyle(1, color);
         circle(pgx[p1], pgy[p1], 2 + (pstyle-2));
         floodfill(pgx[p1]-1, pgy[p1]-1, color);
      }
   }
   return;
}

/*******************************************************************/
// Draws an arrow representing the velocity of particle p in color c
// (drawing it again in the background color erases it).
//
// The shaft runs from the particle to the particle position plus its x/y
// velocity scaled by 300.  The two arrowhead barbs are a simplification: they
// are placed a quarter of the scaled x-velocity behind the tip and offset in y
// by half the square root of the scaled x-velocity's magnitude, so they only
// look right for motion mainly along the x axis; the y-velocity is not used
// for the barbs.  The magnitude is taken before the square root so that a
// negative x-velocity cannot produce NaN coordinates.
void simulation::drawvelocityvector(int p, int c)
{
   const double magfactor = 300.0;               //velocity -> length scale

   int gx, gy;                                   //particle on screen
   int vtipx, vtipy;                             //arrow tip on screen
   int varr1x, varr1y;                           //first barb end
   int varr2x, varr2y;                           //second barb end

   const double scaledvx = pvx[p]*magfactor;
   const double vtipxpl = px[p]+scaledvx;        //arrow tip in simulation space
   const double vtipypl = py[p]+pvy[p]*magfactor;
   const double barbx = vtipxpl - scaledvx/4;
   const double barbdy = 0.5*sqrt(fabs(scaledvx));

   setcolor(c);

   gtrans(&gx, &gy, px[p], py[p], pz[p]);
   gtrans(&vtipx, &vtipy, vtipxpl, vtipypl, pz[p]);
   gtrans(&varr1x, &varr1y, barbx, vtipypl - barbdy, pz[p]);
   gtrans(&varr2x, &varr2y, barbx, vtipypl + barbdy, pz[p]);

   line(gx,gy,vtipx,vtipy);
   line(vtipx,vtipy,varr1x,varr1y);
   line(vtipx,vtipy,varr2x,varr2y);
   return;
}


/*******************************************************************/
// Holds the simulation until the user presses space or requests exit.
//
// While paused, keys are still processed through input(); pauseflag keeps
// input() from starting a nested pause.  On leaving, the info display is
// cleared (unless exiting), a temporary tracers value of 2 is restored to 1,
// and the pause/display flags are reset.
void simulation::pause(void)
{
   pauseflag = 1;

   for (;;)
   {
      if (displayflag == 1) displayinfo();
      input();
      if (com == ' ' || exitflag != 0) break;
   }

   if (exitflag == 0 && displayflag >= 1) clearinfo();
   if (tracers == 2) tracers = 1;

   displayflag = 0;
   pauseflag = 0;
}

/*******************************************************************/
// Changes the zoom by one step of zinc and redraws from a cleared viewport.
//   zdir == 0  zoom in  (only while zoom < 1000)
//   zdir != 0  zoom out (only while zoom > 0.01)
void simulation::imagezoom(int zdir)
{
   if (zdir == 0)
   {
      if (zoom < 1000) zoom *= zinc;
   }
   else if (zoom > 0.01)
   {
      zoom /= zinc;
   }

   clearviewport();
   drawparticles(1, particlestyle);
}

/*******************************************************************/
// Steps the timing-loop delay up or down.
//
//   ddir != 0  increase: 0 -> 100, otherwise double (no doubling once the
//              delay has reached 4000000 or more)
//   ddir == 0  decrease: halve while above 100, and 100 -> 0
//
// Each call performs exactly one step.  In particular, decreasing from 200
// now stops at 100 instead of falling straight through to 0, so the decrease
// sequence visits the same values as the increase sequence.
void simulation::changedelay(int ddir)
{
   if (ddir)
   {
      if (delay == 0) delay = 100;
      else if ((delay > 0) && (delay < 4000000)) delay *= 2;
   }
   else
   {
      if (delay > 100) delay /= 2;
      else if (delay == 100) delay = 0;
   }
   return;
}
/*******************************************************************/
// Prints the simulation statistics in the top-left corner of the screen:
// time (also converted to years), timestep, delay, zoom percentage, view
// center, the rota/rotc rotation angles in degrees and the tracked particle.
//
// Time and timestep are shown without decimals when the timestep is at least
// 1, and with three decimals otherwise.  The literal percent sign after the
// zoom value is written as "%%", the printf spelling of a single '%'.
void simulation::displayinfo(void)
{
   if (timestep >= 1)
   {
      gotoxy(1, 1); printf("    time: %.0f      ", t);
      gotoxy(1, 3); printf("timestep: %.0f      ", timestep);
   }
   else
   {
      gotoxy(1, 1); printf("    time: %.3f      ", t);
      gotoxy(1, 3); printf("timestep: %.3f      ", timestep);
   }
   gotoxy(1, 2); printf("          %.3f yr   ", t * 0.00031688);
   gotoxy(1, 4); printf("   delay: %ld       ", delay);
   gotoxy(1, 5); printf("    zoom: %.0f%%    ", zoom * 100);
   gotoxy(1, 6); printf("x center: %.2f      ", centx);
   gotoxy(1, 7); printf("y center: %.2f      ", centy);
   gotoxy(1, 8); printf("a rotate: %.0f      ", rota / 0.0174532);   //radians -> degrees
   //gotoxy(1, 9); printf("b rotate: %.0f      ", rotb / 0.0174532);
   gotoxy(1, 9); printf("c rotate: %.0f      ", rotc / 0.0174532);
   gotoxy(1, 10); printf("tracking: %d        ", ptrack);
   //gotoxy(1, 11); printf("drawfreq: %d        ", drawfreq);
   return;
}
/*******************************************************************/
// Blanks the info area by printing rows of spaces from the top-left corner.
// The number and width of the rows depend on displayflag:
//   1 -> 12 rows,  2 -> 17 wider rows,  anything else -> nothing is printed.
void simulation::clearinfo(void)
{
   int row;

   gotoxy(1,1);

   switch (displayflag)
   {
      case 1:
         for (row = 0; row < 12; row++) printf("                   \n");
         break;

      case 2:
         for (row = 0; row < 17; row++) printf("                        \n");
         break;

      default:
         break;
   }
}

/*******************************************************************/
void simulation::imagerotate(int rdir)
{
   if (tracers == 1) {tracers = 2; clearviewport();}
   else if (!pauseflag) clearviewport();

   if (rdir == 0) rota += rinc;
   if (rdir == 1) rota -= rinc;
   if (rdir == 2) rotb += rinc;
   if (rdir == 3) rotb -= rinc;
   if (rdir == 4) rotc += rinc;
   if (rdir == 5) rotc -= rinc;

   if (rota < -179 * 0.0174532) rota = 180 * 0.0174532;
   if (rotb < -179 * 0.0174532) rotb = 180 * 0.0174532;
   if (rotc < -179 * 0.0174532) rotc = 180 * 0.0174532;

   if (rota > 184 * 0.0174532) rota = -175 * 0.0174532;
   if (rotb > 184 * 0.0174532) rotb = -175 * 0.0174532;
   if (rotc > 184 * 0.0174532) rotc = -175 * 0.0174532;

   drawparticles(-1, particlestyle);
   drawparticles(2, particlestyle);

   return;
}

/********************************************************/

// Appends a particle with position (x, y, z), velocity (vx, vy, vz), mass m
// and color c.  Its tracer color is set to 8 and its screen position to
// (-1, -1), i.e. "not drawn yet".
//
// All maxp slots are usable: the particle is stored at index nump and nump is
// then incremented.  When every slot is already taken the call is ignored, so
// an existing particle is never overwritten.
void simulation::addobject(double x, double y, double z, double vx, double vy,
			   double vz, double m, int c)
{
  if (nump < 0 || nump >= maxp) return;     //no free slot

  px[nump] = x;  py[nump] = y;  pz[nump] = z;
  pvx[nump] = vx; pvy[nump] = vy; pvz[nump] = vz;
  pm[nump] = m; pc[nump] = c; pc2[nump] = 8;
  pgx[nump] = pgy[nump] = -1;
  nump += 1;
  return;
}

void simulation::moveobject(int obj, double dx, double dy, double dz)
{
  px[obj] += dx;
  py[obj] += dy;
  pz[obj] += dz;
}

void simulation::accobject(int obj, double dvx, double dvy, double dvz)
{
  pvx[obj] += dvx;
  pvy[obj] += dvy;
  pvz[obj] += dvz;
}

void simulation::flush()
{
  nump = 0;
}

void simulation::opengraphicsmode(char *mode)
{
   int gdriver, gmode, errorcode;
   gdriver = DETECT;
   initgraph(&gdriver, &gmode, mode);
}

// Runs the simulation from t = 0 until the user requests exit.
//
// Each pass of the loop redraws the particles (every drawfreq iterations),
// advances the system by one timestep with core(), burns the configured
// delay, and processes a key if one is waiting.  When the loop ends, the
// termination time is printed and the on-screen key help is blanked.
void simulation::runsim()
{
   t = 0.0;
   iter = 0;
   volatile long n;        //volatile so the empty delay loop is not optimised away

   drawparticles(2,2);
   exitflag = 0;

   gotoxy(7, 1); printf("t =");
   gotoxy(1, 3); printf("spc Pause");
   gotoxy(1, 4); printf("esc New parameters");

   do
   {
      //a drawfreq of 0 means "draw every iteration" instead of dividing by zero
      if (drawfreq == 0 || iter % drawfreq == 0)
      {
         gotoxy(11,1); printf("%.3f yr", t * 0.00031688);
         drawparticles(0, particlestyle);      //tracer/erase at the old position
         drawparticles(1, particlestyle);      //particle at the new position
      }

      core();

      for (n = 1; n < delay; n ++);             //delay

      if (kbhit()) input();

      if (!exitflag) {iter++; t += timestep;}

   } while (!exitflag);

   gotoxy(38, 1); printf("simulation terminated at %.0f", t);
   gotoxy(1, 3); printf("         ");
   gotoxy(1, 4); printf("                  ");

}

// Draws a one-pixel frame from (0, 0) to (639, 479) with four line() calls.
void box(void)
{
      const int right = 639;
      const int bottom = 479;

      line(0, 0, right, 0);              //top edge
      line(0, 0, 0, bottom);             //left edge
      line(right, 0, right, bottom);     //right edge
      line(0, bottom, right, bottom);    //bottom edge
}

// Tutorial driver.
//
// Repeatedly sets up a two-body system (a heavy body at the origin and a light
// body above it), lets the user adjust the light body's starting separation
// and velocity with the arrow keys, and runs the simulation when space is
// pressed.  Esc in the setup menu quits; Esc during a run returns to the
// menu with the last chosen parameters.
//
// Arrow keys arrive from getch() as a 0 prefix followed by a scan code.  The
// prefix is consumed here, so only real arrow keys adjust the parameters
// (the letters H, K, M and P, which share the scan-code values, no longer
// do), and a scan code that happens to equal Esc or space cannot leave the
// menu.
int main(void)
{
   char com = 0;

   double startpos = 15.0;
   double startvel = 0.030;

   //writable buffer: opengraphicsmode() takes a non-const char *
   char bgipath[] = "EGAVGA.BGI";

   simulation tutorial(0);

   tutorial.opengraphicsmode(bgipath);

   while (com != 27)
   {
      //fresh two-body setup using the most recently chosen parameters
      tutorial.flush();
      tutorial.addobject(0,0,0,        0,0,0,        20,    14);
      tutorial.addobject(0,startpos,0, startvel,0,0, 0.00006, 11);
      tutorial.tracers = 0;
      clearviewport();
      tutorial.drawparticles(2,2);
      com = 0;

      gotoxy(1,1); printf(" %c Decrease separation", 24);
      gotoxy(1,2); printf(" %c Increase separation", 25);
      gotoxy(1,3); printf("<- Decrease velocity");
      gotoxy(1,4); printf("-> Increase velocity");
      gotoxy(1,5); printf(" d Defaults (Earth)");
      gotoxy(1,6); printf("spc Start!");
      gotoxy(1,7); printf("esc Quit");

      while (com != 27 && com != ' ')
      {
         tutorial.drawparticles(0,2);
         tutorial.drawparticles(1,2);
         tutorial.drawvelocityvector(1,7);

         gotoxy(36,1); printf("initial separation = %.0f x 10^6 km ", 10 * tutorial.py[1]);
         gotoxy(36,2); printf("  initial velocity = %.1f km/s ", tutorial.pvx[1] * 1000);

         const int key = getch();

         if (key == 0)
         {
            //extended key: the next getch() delivers the scan code
            const int scan = getch();
            com = 0;             //a scan code must never look like Esc/space

            if (scan == 72 || scan == 75 || scan == 77 || scan == 80)
            {
               //erase the old arrow before the particle changes
               tutorial.drawvelocityvector(1,0);
            }
            if (scan == 72) tutorial.moveobject(1, 0,-1,0);          //up
            if (scan == 75) tutorial.accobject(1, -0.0001,0,0);      //left
            if (scan == 77) tutorial.accobject(1, 0.0001,0,0);       //right
            if (scan == 80) tutorial.moveobject(1, 0,1,0);           //down
         }
         else
         {
            com = (char)key;
            if (com == 'd')
            {
               tutorial.drawvelocityvector(1,0);
               tutorial.py[1] = 14.9;
               tutorial.pvx[1] = 0.030;
            }
         }

         //keep the parameters within the allowed range
         if (tutorial.pvx[1] < 0.005) tutorial.pvx[1] = 0.005;
         if (tutorial.py[1] < 3) tutorial.py[1] = 3;
      }

      if (com == ' ')
      {
         //blank the menu
         for (int l = 1; l <= 7; l++)
         {
            gotoxy(1,l); printf("                      ");
         }

         //remember the chosen parameters for the next round
         startpos = tutorial.py[1];
         startvel = tutorial.pvx[1];
         tutorial.drawvelocityvector(1,0);
         tutorial.tracers = 1;
         tutorial.runsim();

      }
   }

   closegraph();

   return 0;
}
