/*****
 * demo of the RSA algorithm.
 * 
 * I'm *not* trying to be secure or realistic here,
 * but just to show the general method.
 *
 ****/

// hit long overflow in encode with p=617, q=503


#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Key material shared by every function in this file.
// P and Q are fixed here; N, B, E and D are filled in by init().
long P = 131;  // these are both prime.  See bigint/low_primes.h
long Q = 89;  
long N;        // modulus, P*Q
long B;        // (P-1)*(Q-1)
long E;        // exponent of the public key (E,N)
long D;        // exponent of the private key (D,N); E*D mod B should be 1

// ----------------------------
// Inverse of e modulo m, in the range [1, m), found with the extended
// Euclidean algorithm.  Returns 0 when gcd(e, m) != 1, i.e. when e has
// no inverse modulo m.  Expects e >= 0 and m >= 1.
static long init_modinv(long e, long m) {
  long r0 = m;      // remainder sequence of Euclid's algorithm
  long r1 = e % m;
  long t0 = 0;      // matching coefficients: t*e == r (mod m)
  long t1 = 1;

  while (r1 != 0) {
    long q  = r0 / r1;
    long r2 = r0 - q * r1;
    long t2 = t0 - q * t1;
    r0 = r1;  r1 = r2;
    t0 = t1;  t1 = t2;
  }
  if (r0 != 1) {
    return 0;       // e and m share a factor
  }
  return (t0 < 0) ? t0 + m : t0;
}

// Derive the key constants N, B, E and D from the primes P and Q.
//
//   N = P*Q
//   B = (P-1)*(Q-1)
//   E = smallest odd number >= 3 that is coprime to B
//   D = inverse of E modulo B, so that (E*D) mod B == 1
//
// Testing only "B % E != 0" is not enough: a composite candidate such
// as 9 can fail to divide B and still share the factor 3 with it, and
// then no D exists.  Requiring an inverse guarantees a usable pair.
// If no exponent exists at all (only possible when B <= 3) the program
// reports the problem and exits.
void init(void) {
  long i;

  N = P*Q;
  B = (P-1)*(Q-1);

  E = 0;
  D = 0;
  for (i=3; i<B; i+=2){
    long inverse = init_modinv(i, B);
    if ( inverse != 0 ) {
      E = i;
      D = inverse;
      break;
    }
  }

  if ( E == 0 ) {
    fprintf(stderr, " OOPS - no usable exponent E for P=%ld, Q=%ld \n", P, Q);
    exit(1);
  }
}

// ----------------------------
// Print sizeof(long) and the value 2**32-1 in decimal and hex.
//
// The value is built in an unsigned long so the doubling is well defined
// even where long is only 32 bits wide (2**32-1 does not fit in a signed
// 32-bit long, and signed overflow is undefined behaviour).
void testLongSize(void){
  unsigned long n;
  int i;
  printf( "sizeof(long)=%zu \n", sizeof(long));
  n=1;
  for (i=0;i<31;i++){
    n = 2*n + 1;   
    //    or, equivalently
    // n = (n<<1) | 1 ;
  }
  // n started with 1 bit set and gained one more per pass: i+1 bits total.
  printf( " a big long = 2**%d-1 = %lu base 10 = 0x%lx hex \n", i+1,n,n);
}

// ----------------------------
// Print the primes, the derived constants and both keys, together with
// the two sanity checks (B mod E and E*D mod B).
// Prints a short notice and returns if E or B is still zero, because the
// sanity checks would otherwise divide by zero.
void showConstants(void){
  if (E == 0 || B == 0) {
    printf(" constants are not set up yet (E=%ld, B=%ld) \n", E, B);
    return;
  }
  printf(" P= %ld, Q=%ld \n N=PQ=%ld=0x%lx \n B=(P-1)(Q-1)=%ld \n",
         P,Q,N,(unsigned long)N,B);
  printf(" E=%ld, D=%ld \n",E,D);
  printf(" B mod E = %ld  <<== should be non-zero \n", B % E);
  printf(" ED mod B = %ld <<== should be 1 \n", (E*D) % B);
  printf("\n");
  printf(" Public key: (%ld,%ld) \n", E,N);
  printf(" Private key: (%ld,%ld) \n", D,N);
  printf("\n");
}


// ----------------------------
// (x * y) mod m, computed by repeated doubling so that no intermediate
// value ever exceeds m.  Requires 0 <= x < m and 0 <= y < m.
static long powermod_mulmod(long x, long y, long m) {
  long product = 0;
  while (y > 0) {
    if (y & 1) {
      // product + x (mod m), phrased so the sum itself is never formed
      // when it would reach m.
      product = (product >= m - x) ? product - (m - x) : product + x;
    }
    x = (x >= m - x) ? x - (m - x) : x + x;   // 2x (mod m)
    y >>= 1;
  }
  return product;
}

// return a**b mod c
//
//   a  base; any value, it is first reduced into the range [0, c)
//   b  exponent; values <= 0 are treated as 0, giving 1 mod c
//   c  modulus; must be > 0, otherwise a message is printed and the
//      program exits
//
// Uses binary square-and-multiply, so the number of steps grows with the
// number of bits in b rather than with b itself.  Every product goes
// through powermod_mulmod(), so the result is correct for any modulus
// that fits in a long; nothing overflows.
long powermod( long a, long b, long c){
  long answer;
  long base;

  if (c <= 0) {
    fprintf(stderr, " OOPS - powermod needs a modulus > 0, got c=%ld \n", c);
    exit(1);
  }

  base = a % c;
  if (base < 0) {
    base += c;        // C's % keeps the sign of a; move into [0, c)
  }
  answer = 1 % c;     // 0 when c == 1, otherwise 1

  while (b > 0) {
    if (b & 1) {
      answer = powermod_mulmod(answer, base, c);
    }
    base = powermod_mulmod(base, base, c);
    b >>= 1;
  }
  return answer;
}



// ----------------------------
// Demo driver: set up the keys, print them, then encode and decode one
// hard-coded message and print each stage.
int main(void){
  
  long M;
  long C;

  // testLongSize();

  init();
  showConstants();

  // Now do mod arithmetic to encode, namely
  // (1) Figure out how many bits can go in M, the message.
  // (2) Break message into chunks of that many bits.
  //     On each chunk,
  // (3) Calculate M**E mod N with powermod().
  // (4) Output C = M**E mod N, has same number of bits in chunk as input.
  //
  // This demo skips (1) and (2) and encodes the single chunk M below.

  // Decode is the same thing but with M = (C**D mod N).

  M=1234; // message
  printf(" message=%ld=0x%lx \n", M,(unsigned long)M);

  C = powermod(M,E,N);
  printf(" encoded=%ld \n", C);

  M = powermod(C,D,N);
  printf(" decoded=%ld \n", M);

  return 0;
}



// syntax highlighted by Code2HTML, v. 0.9.1