/*******
 * dct_matrix.c
 *
 * Looking at the matrix for a Discrete Cosine Transform (DCT).
 *
 *
 *******/

#include <stdio.h>
#include <math.h>

// Number of points.  JPEG typically uses 8x8 two-dimensional DCTs.
#define N 8

// ------------------------------------------------------
// Fill the N*N (64 for N=8) entries of the DCT matrix as dct[u][x]:
//   dct[u][x] = C(u) * sqrt(2/N) * cos((2x+1) * u * pi / (2N))
// where C(0) = 1/sqrt(2) and C(u) = 1 for u > 0.  Row u is the u-th cosine
// basis vector sampled at x = 0..N-1.  The inverse matrix is not built here;
// since the DCT matrix is orthogonal, its transpose serves as the inverse.
void defineDCT(double dct[N][N]){
  const double pi = 4.0 * atan(1.0);      // tan(pi/4) = 1, so pi = 4*atan(1)
  const double step = pi / (2 * N);       // angular step pi/(2N)
  const double dcWeight = 1.0 / sqrt(2.0);
  const double norm = sqrt(2.0 / N);

  for (int freq = 0; freq < N; freq++) {
    // The per-row factor is norm*C(u); computing it once per row keeps the
    // same multiplication order as (norm*C)*cos(...).
    const double weight = (freq == 0) ? dcWeight : 1;
    const double scale = norm * weight;
    double *row = dct[freq];
    for (int sample = 0; sample < N; sample++) {
      row[sample] = scale * cos((2 * sample + 1) * freq * step);
    }
  }
}

// ------------------------------------------------------
// Print the N x N matrix m to stdout, one row per line.  Each line starts
// with the marker character c, followed by every entry in " % 7.5f" format
// (leading space-for-sign, width 7, 5 decimals).
void printMatrix(double m[N][N], char c){
  for (int r = 0; r < N; r++) {
    const double *cells = m[r];
    putchar(c);
    for (int k = 0; k < N; k++) {
      printf(" % 7.5f", cells[k]);
    }
    putchar('\n');
  }
}

// ------------------------------------------------------
// Matrix product c = a*b for N x N matrices:
//   c[i][j] = sum_k a[i][k] * b[k][j]
// All three arrays must already exist; nothing is allocated here.  c should
// not share storage with a or b, because entries of c are written while a
// and b are still being read.
void multiplyMatrices( double a[N][N], double b[N][N], double c[N][N]){
  for (int i = 0; i < N; i++) {
    const double *aRow = a[i];
    double *cRow = c[i];
    for (int j = 0; j < N; j++) {
      double acc = 0.0;
      for (int k = 0; k < N; k++) {
        acc += aRow[k] * b[k][j];
      }
      cRow[j] = acc;
    }
  }
}

// ------------------------------------------------------
// Write the transpose of a into aTrans, so that aTrans[j][i] = a[i][j].
// The DCT matrix is orthogonal, so its transpose is also its inverse.
// a and aTrans should be distinct arrays: an in-place call would overwrite
// entries before they are read.
void transposeMatrix(double a[N][N], double aTrans[N][N]){
  for (int i = 0; i < N; i++) {
    const double *src = a[i];
    for (int j = 0; j < N; j++) {
      aTrans[j][i] = src[j];
    }
  }
}

// ------------------------------------------------------
// Build the N-point DCT matrix and its inverse (the transpose).  Print both,
// then print their product.  Lines of the DCT matrix and of the product are
// prefixed with '#'; the transpose is printed with a ' ' prefix.
int main(void){

  double dct[N][N];		// All these are [row][column]
  double dctInv[N][N];
  double product[N][N];

  defineDCT(dct);
  // The DCT matrix is orthogonal, so its transpose is its inverse.
  transposeMatrix(dct, dctInv);

  printf("# Here's the matrix for discrete cosine transform.\n\n");
  // defineDCT fills dct[u][x] and printMatrix prints the first index as the
  // row, so rows are indexed by u and columns by x.
  printf("# The rows are u=0..(N-1), columns are x=0..(N-1)\n");
  printf("# The transform and its inverse are given by \n");
  printf("#  S(u) = sum_x { f(x)* C(u)*sqrt(2/N)*cos((2x+1)*u*pi/(2*N)) } \n");
  printf("#  f(x) = sum_u { S(u)* C(u)*sqrt(2/N)*cos((2x+1)*u*pi/(2*N)) } \n");
  printf("# where C(u)=1/sqrt(2) if u=0 and C(u)=1 if u is not 0.\n");
  printf("# Using N=%d.\n",N);
  printf("#\n");
  printMatrix(dct, '#');

  printf("#\n");
  printf("# And here's the transpose: \n\n");
  printMatrix(dctInv, ' ');
  printf("\n");

  // For an orthogonal matrix this product should be (numerically) the identity.
  printf("# The matrix product of DCT * transpose(DCT) is \n");
  multiplyMatrices(dct, dctInv, product);
  printMatrix(product, '#');

  return 0;
}



// syntax highlighted by Code2HTML, v. 0.9.1