numeric-lib/matrix.h
00001 /*
00002  * matrix.h
00003  *
00004  *  Created on: Dec 6, 2008
00005  *      Author: tdillig
00006  *
00007  *  This class allows for infinite precision matrix computations.
00008  *  To increase efficiency, the matrix "upgrades" entire rows at a time to
00009  *  GNU MP values rather than representing each individual element as a bignum.
00010  */
00011 
00012 #ifndef MATRIX_H_
00013 #define MATRIX_H_
00014 
00015 #include <stdlib.h>
00016 #include <string.h>
00017 #include <gmp.h>
00018 #include "bignum.h"
00019 #include <assert.h>
00020 #include <vector>
00021 using namespace std;
00022 
00023 
00024 #include <map>
00025 #include <set>
00026 
00027 
00028 typedef long int d_type;
00029 
00030 
00031 
00032 
00033 
00034 typedef d_type v4si __attribute__ ((vector_size (32)));
00035 
00036 typedef d_type v2si __attribute__ ((vector_size (16)));
00037 
00038 /*
00039  * Should operations on rows use SSE vector (SIMD) instructions?
00040  */
00041 #define VECTORIZE_TWO true
00042 
00043 //#define VECTORIZE_FOUR true
00044 
00045 
00046 
00047 
00048 
/*
 * MD(rr,cc): native (d_type) view of entry (rr,cc). The row's storage starts
 * at dmatrix + rr*mcols and is reinterpreted as a packed array of d_type, so
 * this accessor is only meaningful for rows that have not been upgraded.
 * Arguments are parenthesized so expressions such as MD(r+1, c) expand
 * correctly.
 */
#define MD(rr,cc) (((d_type*)(dmatrix+((rr)*mcols)))[(cc)])


//#define MD(rr,cc)((dmatrix[rr*mcols+cc]).d)
/*
 * MI(rr,cc): GMP (mpz) view of entry (rr,cc), i.e. the .i member of the
 * data_type cell at index rr*mcols+cc.
 */
#define MI(rr,cc) ((dmatrix[(rr)*mcols+(cc)]).i)
00054 
00055 
00056 /*
00057  * Enables bounds checking for each operation. Since most checks are
00058  * amortized over a row operation, the expected slow-down is ~5%.
00059  */
00060 #define CHECK_BOUNDS false
00061 
00062 //#define ENABLE_CHECKS true
00063 
00064 #define DWORD_ALIGN(x) (((((x)+1)+3)/4)*4)
00065 
00066 
00067 class matrix {
00068         friend class slack_matrix;
00069 
00070 protected:
00071         data_type *dmatrix;
00072         const int rows;
00073         const int cols;
00074         const int mcols;
00075         bool *big_rows;
00076         vector<string> *vars;
00077 
00078 /*
00079  * If checking is enabled, we compare the matrix with a matrix of
00080  * bignums.
00081  */
00082 #ifdef ENABLE_CHECKS
00083         bignum* cmatrix;
00084 #endif
00085 
00086 public:
00087         /*
00088          * Constructs an infinite precision matrix of size rows x num_vars+2 and
00089          * initializes all values to 0.
00090          */
00091         inline matrix(int rows, int num_vars,
00092                         vector<string> *vars):
00093                                 rows(rows), cols(num_vars+2), mcols(DWORD_ALIGN(cols+1)),
00094                                 vars(vars)
00095         {
00096                 dmatrix = new data_type[(rows+2)*mcols];
00097                 big_rows = new bool[rows];
00098                 memset(big_rows, 0, rows*sizeof(bool));
00099                 memset(dmatrix, 0, rows*mcols*sizeof(data_type));
00100 
00101 
00102 #ifdef ENABLE_CHECKS
00103                 cmatrix = new bignum[(rows+2)*mcols];
00104                 for(int r=0; r<rows; r++) {
00105                         infinitize_row(r);
00106                 }
00107 #endif
00108 
00109         }
00110 
00111         inline matrix(int rows, int cols):
00112                                         rows(rows), cols(cols), mcols(DWORD_ALIGN(cols+1))
00113         {
00114                 dmatrix = new data_type[(rows+2)*mcols];
00115                 big_rows = new bool[rows];
00116                 memset(big_rows, 0, rows*sizeof(bool));
00117                 memset(dmatrix, 0, rows*mcols*sizeof(data_type));
00118                 vars = NULL;
00119         }
00120 
00121 
00122         inline int num_rows()
00123         {
00124                 return rows;
00125         }
00126 
00127 
00128         inline int num_vars()
00129         {
00130                 return cols-2;
00131         }
00132         inline int num_cols()
00133         {
00134                 return cols;
00135         }
00136 
00137         /*
00138          * Creates a duplicate of other matrix. Takes an (optional) argument
00139          * specifying how many rows should be added.
00140          */
00141         inline matrix(const matrix& other, int new_rows=-1):
00142                 rows(new_rows==-1 ? other.rows : other.rows+new_rows),
00143         cols(other.cols), mcols(other.mcols),vars(other.vars)
00144         {
00145 
00146 
00147 
00148                 dmatrix = new data_type[(rows+2)*mcols];
00149                 big_rows = new bool[rows];
00150                 if(new_rows!=-1){
00151                         memset(big_rows, 0, rows*sizeof(bool));
00152                         memset(dmatrix, 0, rows*mcols*sizeof(data_type));
00153 
00154                 }
00155                 memcpy(big_rows, other.big_rows, rows*sizeof(bool));
00156                 for(int r=0; r < rows; r++) {
00157                         if(!big_rows[r]) {
00158                                 memcpy(dmatrix+r*mcols, other.dmatrix+r*mcols,
00159                                                 cols*sizeof(data_type));
00160                                 continue;
00161                         }
00162                         int end = r*mcols+cols;
00163                         for(int i = r*mcols; i < end; i++) {
00164                                 mpz_init_set(dmatrix[i].i, other.dmatrix[i].i);
00165                         }
00166 
00167                 }
00168 #ifdef ENABLE_CHECKS
00169                 cmatrix = new bignum[(rows+2)*mcols];
00170                 for(int r = 0; r < rows; r++)
00171                 {
00172                         for(int c=0; c < cols; c++)
00173                         {
00174                                 cmatrix[r*mcols+c] = other.cmatrix[r*mcols+c];
00175                         }
00176                 }
00177 #endif
00178         }
00179 
00180         /*
00181          * Simplifies all entries in the matrix.
00182          */
00183         inline void simplify_matrix()
00184         {
00185                 for(int r=0; r < rows; r++) {
00186                         simplify_row(r);
00187                 }
00188         }
00189 
00190         /*
00191          * Sets matrix[r][c] = v.
00192          */
00193         inline void set(int r, int c, bignum  v)
00194         {
00195 
00196 
00197                 if(CHECK_BOUNDS){
00198                         assert(r>=0 && r<rows);
00199                         assert(c>=0 && c<cols);
00200                 }
00201 #ifdef ENABLE_CHECKS
00202                 cmatrix[r*mcols+c] = v;
00203 #endif
00204 
00205                 if(!v.infinite && !big_rows[r]){
00206                         //dmatrix[r*mcols+c].d = v.data.d;
00207                         MD(r, c) = v.data.d;
00208                         return;
00209                 }
00210                 if(!big_rows[r]){
00211                         infinitize_row(r);
00212                 }
00213                 if(v.infinite){
00214                         //mpz_set(dmatrix[r*mcols+c].i, v.data.i);
00215                         mpz_set(MI(r,c), v.data.i);
00216                 }
00217                 else {
00218                         //mpz_set_si(dmatrix[r*mcols+c].i, v.data.d);
00219                         mpz_set_si(MI(r, c), v.data.d);
00220                 }
00221         }
00222 
00223 
00224         /*
00225          * Sets matrix[r][c] = v.
00226          */
00227         inline void set(int r, int c, long int i)
00228         {
00229                 if(CHECK_BOUNDS){
00230                         assert(r>=0 && r<rows);
00231                         assert(c>=0 && c<cols);
00232                 }
00233 #ifdef ENABLE_CHECKS
00234                 cmatrix[r*mcols+c] = bignum(i);
00235 #endif
00236                 if(!big_rows[r]){
00237                         //dmatrix[r*mcols+c].d = i;
00238                         MD(r,c) = i;
00239                         return;
00240                 }
00241                 //mpz_set_si(dmatrix[r*mcols+c].i, i);
00242                 mpz_set_si(MI(r, c), i);
00243 
00244         }
00245         /*
00246          * Returns matrix[r][c].
00247          */
00248         inline bignum get(int r, int c)
00249         {
00250                 if(CHECK_BOUNDS){
00251                         assert(r>=0 && r<rows);
00252                         assert(c>=0 && c<cols);
00253                 }
00254                 if(!big_rows[r]){
00255                         //bignum bg(dmatrix[r*mcols+c].d);
00256                         bignum bg(MD(r, c));
00257                         return bg;
00258                 }
00259                 //return bignum(dmatrix[r*mcols+c].i);
00260                 return bignum(MI(r, c));
00261 
00262         }
00263 
00264         /*
00265          * Multiplies row r by f.
00266          */
00267         inline void multiply_row(int r, bignum f)
00268         {
00269                 if(CHECK_BOUNDS){
00270                         assert(r>=0 && r<rows);
00271                 }
00272 #ifdef ENABLE_CHECKS
00273                 for(int i=r*mcols; i<r*mcols+cols-1; i++) {
00274                         cmatrix[i]*=f;
00275                 }
00276 #endif
00277 
00278 
00279                 if(big_rows[r]){
00280                         multiply_row_i(r, f);
00281                         return;
00282                 }
00283                 if(f.infinite){
00284                         infinitize_row(r);
00285                         multiply_row_i(r, f);
00286                         return;
00287                 }
00288 
00289 
00290 
00291                 long int max = labs(MD(r, 0));
00292                 /*for(int i=r*mcols+1; i<end; i++) {
00293                         if(labs(dmatrix[i].d)>max){
00294                                 max = labs(dmatrix[i].d);
00295                         }
00296                 }*/
00297                 for(int i=1; i<cols-1; i++) {
00298                         if(labs(MD(r, i))>max){
00299                                 max = labs(MD(r, i));
00300                         }
00301                 }
00302                 if(bignum::m_overflow(max, f.data.d)){
00303                         infinitize_row(r);
00304                         multiply_row_i(r, f);
00305                         return;
00306                 }
00307 
00308                 /*for(int i=r*mcols; i<end; i++) {
00309                         dmatrix[i].d*=f.data.d;
00310                 }*/
00311 #ifdef VECTORIZE_TWO
00312                 {
00313                         v2si c = {f.data.d, f.data.d};
00314                         d_type last = MD(r, cols-1);
00315                         v2si* cur = (v2si*)&(MD(r, 0));
00316                         v2si* end = (v2si*)&(MD(r, cols-1));
00317                         for(; cur <end; cur++) {
00318 
00319                                 (*cur)*=c;
00320                                 //MD(r, i)*=f.data.d;
00321                         }
00322                         MD(r, cols-1) = last;
00323                 }
00324 #endif
00325 
00326 #ifdef VECTORIZE_FOUR
00327                 {
00328                         v4si c = {f.data.d, f.data.d, f.data.d, f.data.d};
00329                         d_type last = MD(r, cols-1);
00330                         v4si* cur = (v4si*)&(MD(r, 0));
00331                         v4si* end = (v4si*)&(MD(r, cols-1));
00332                         for(; cur <end; cur++) {
00333 
00334                                 (*cur)*=c;
00335                                 //MD(r, i)*=f.data.d;
00336                         }
00337                         MD(r, cols-1) = last;
00338                 }
00339 #endif
00340 
00341 
00342 #ifndef VECTORIZE_TWO
00343 #ifndef VECTORIZE_FOUR
00344                 else
00345                 {
00346                         for(int i=0; i<cols-1; i++) {
00347                                 MD(r, i)*=f.data.d;
00348                         }
00349                 }
00350 #endif
00351 #endif
00352 
00353         }
00354 
00355         /*
00356          * Negates row r (every column except the last one).
00357          */
00358         inline void flip_row_sign(int r)
00359         {
00360                 if(CHECK_BOUNDS){
00361                         assert(r>=0 && r<rows);
00362                 }
00363 #ifdef ENABLE_CHECKS
00364                 for(int i=r*mcols; i<r*mcols+cols-1; i++) {
00365                                 cmatrix[i]=-cmatrix[i];
00366                 }
00367 #endif
00368 
00369                 if(big_rows[r]){
00370                         //for(int i=r*mcols; i < end; i++) {
00371                         //      mpz_neg(dmatrix[i].i, dmatrix[i].i);
00372                         //}
00373                         for(int i=0; i < cols-1; i++) {
00374                                 mpz_neg(MI(r, i), MI(r, i));
00375                         }
00376                         return;
00377                 }
00378                 //for(int i=r*mcols; i < end; i++) {
00379                 //      dmatrix[i].d=-dmatrix[i].d;
00380                 //}
00381 
00382 #ifdef VECTORIZE_TWO
00383                 {
00384                         d_type last = MD(r, cols-1);
00385                         v2si* cur = (v2si*)&(MD(r, 0));
00386                         v2si* end = (v2si*)&(MD(r, cols-1));
00387                         for(; cur <end; cur++) {
00388 
00389                                 (*cur)=-*cur;
00390                                 //MD(r, i)*=f.data.d;
00391                         }
00392                         MD(r, cols-1) = last;
00393                 }
00394 #endif
00395 
00396 #ifdef VECTORIZE_FOUR
00397                 {
00398                         d_type last = MD(r, cols-1);
00399                         v4si* cur = (v4si*)&(MD(r, 0));
00400                         v4si* end = (v4si*)&(MD(r, cols-1));
00401                         for(; cur <end; cur++) {
00402 
00403                                 (*cur)=-*cur;
00404                                 //MD(r, i)*=f.data.d;
00405                         }
00406                         MD(r, cols-1) = last;
00407                 }
00408 #endif
00409 
00410 
00411 
00412 
00413 
00414 #ifndef VECTORIZE_TWO
00415 #ifndef VECTORIZE_FOUR
00416 
00417                 for(int i=0; i < cols-1; i++) {
00418                         MD(r, i)=-MD(r, i);
00419                 }
00420 #endif
00421 #endif
00422 
00423 
00424         }
00425 
00426 
00427         /*
00428          * Note: This function assumes that f divides every elem cleanly &
00429          * that f>0.
00430          */
00431         inline void divide_row(int r, bignum f)
00432         {
00433                 if(CHECK_BOUNDS){
00434                         assert(r>=0 && r<rows);
00435                         assert(f>=0);
00436                 }
00437 #ifdef ENABLE_CHECKS
00438                 for(int i=r*mcols; i<r*mcols+cols-1; i++) {
00439                                 cmatrix[i]/=f;
00440                 }
00441 #endif
00442 
00443                 if(big_rows[r]){
00444                         divide_row_i(r, f);
00445                         return;
00446                 }
00447                 if(f.infinite){
00448                         infinitize_row(r);
00449                         divide_row_i(r, f);
00450                         return;
00451                 }
00452 
00453 
00454                 //for(int i=r*mcols; i<end; i++) {
00455                 //      dmatrix[i].d/=f.data.d;
00456                 //}
00457 
00458 #ifdef VECTORIZE_TWO
00459                 {
00460                         d_type last = MD(r, cols-1);
00461                         v2si c = {f.data.d, f.data.d};
00462                         v2si* cur = (v2si*)&(MD(r, 0));
00463                         v2si* end = (v2si*)&(MD(r, cols-1));
00464                         for(; cur <end; cur++) {
00465 
00466                                 (*cur)/=c;
00467                                 //MD(r, i)*=f.data.d;
00468                         }
00469                         MD(r, cols-1) = last;
00470                 }
00471 #endif
00472 
00473 #ifdef VECTORIZE_FOUR
00474                 {
00475                         d_type last = MD(r, cols-1);
00476                         v4si c = {f.data.d, f.data.d, f.data.d, f.data.d};
00477                         v4si* cur = (v4si*)&(MD(r, 0));
00478                         v4si* end = (v4si*)&(MD(r, cols-1));
00479                         for(; cur <end; cur++) {
00480 
00481                                 (*cur)/=c;
00482                                 //MD(r, i)*=f.data.d;
00483                         }
00484                         MD(r, cols-1) = last;
00485                 }
00486 #endif
00487 
00488 #ifndef VECTORIZE_TWO
00489 #ifndef VECTORIZE_FOUR
00490                 for(int i=0; i<cols-1; i++) {
00491                         MD(r, i)/=f.data.d;
00492                 }
00493 #endif
00494 #endif
00495 
00496 
00497 
00498         }
00499 
00500 
00501 
00502         inline void add_row(int r, bignum f)
00503         {
00504                 if(CHECK_BOUNDS){
00505                         assert(r>=0 && r<rows);
00506                 }
00507 #ifdef ENABLE_CHECKS
00508                 for(int i=r*mcols; i<r*mcols+cols-1; i++) {
00509                         cmatrix[i]+=f;
00510                 }
00511 #endif
00512 
00513                 if(big_rows[r]){
00514                         add_row_i(r, f);
00515                         return;
00516                 }
00517                 if(f.infinite){
00518                         infinitize_row(r);
00519                         add_row_i(r, f);
00520                         return;
00521                 }
00522 
00523                 //long int max = labs(dmatrix[r*mcols].d);
00524                 long int max = labs(MD(r,0));
00525                 //for(int i=r*mcols+1; i<end; i++) {
00526                 //      if(labs(dmatrix[i].d)>max)
00527                 //              max = labs(dmatrix[i].d);
00528                 //}
00529                 for(int i=1; i<cols-1; i++) {
00530                         if(labs(MD(r,i))>max)
00531                                 max = labs(MD(r,i));
00532                 }
00533                 if(bignum::a_overflow(max, f.data.d)){
00534                         infinitize_row(r);
00535                         add_row_i(r, f);
00536                         return;
00537                 }
00538 
00539                 //for(int i=r*mcols; i<end; i++) {
00540                 //      dmatrix[i].d+=f.data.d;
00541                 //}
00542 
00543 #ifdef VECTORIZE_TWO
00544                 {
00545                         d_type last = MD(r, cols-1);
00546                         v2si c = {f.data.d, f.data.d};
00547                         v2si* cur = (v2si*)&(MD(r, 0));
00548                         v2si* end = (v2si*)&(MD(r, cols-1));
00549                         for(; cur <end; cur++) {
00550 
00551                                 (*cur)+=c;
00552                                 //MD(r, i)*=f.data.d;
00553                         }
00554                         MD(r, cols-1) = last;
00555                 }
00556 #endif
00557 
00558 #ifdef VECTORIZE_FOUR
00559                 {
00560                         d_type last = MD(r, cols-1);
00561                         v4si c = {f.data.d, f.data.d, f.data.d, f.data.d};
00562                         v4si* cur = (v4si*)&(MD(r, 0));
00563                         v4si* end = (v4si*)&(MD(r, cols-1));
00564                         for(; cur <end; cur++) {
00565 
00566                                 (*cur)+=c;
00567                                 //MD(r, i)*=f.data.d;
00568                         }
00569                         MD(r, cols-1) = last;
00570                 }
00571 #endif
00572 
00573 #ifndef VECTORIZE_TWO
00574 #ifndef VECTORIZE_FOUR
00575 
00576                 for(int i=0; i<cols-1; i++) {
00577                         MD(r,i)+=f.data.d;
00578                 }
00579                 return;
00580 #endif
00581 #endif
00582 
00583         }
00584 
00585         inline ~matrix()
00586         {
00587                 for(int r=0; r < rows; r++) {
00588                         if(!big_rows[r]) continue;
00589                         delete_bigrow(r);
00590                 }
00591                 delete[] big_rows;
00592                 delete[] dmatrix;
00593 #ifdef ENABLE_CHECKS
00594                 delete[] cmatrix;
00595 #endif
00596         }
00597 
00598         inline void delete_bigrow(int r){
00599                 //for(int i=r*mcols; i < end; i++) {
00600                 //      mpz_clear(dmatrix[i].i);
00601                 //}
00602                 for(int i=0; i < cols; i++) {
00603                         mpz_clear(MI(r,i));
00604                 }
00605         }
00606 
00607         string to_string()
00608         {
00609                 string res;
00610                 if(vars != NULL){
00611                         for(int c=0; c<cols-2; c++) {
00612                                 if((int)vars->size()>c)
00613                                         res+=(*vars)[c];
00614                                 else res+="<u>";
00615                                 res+="\t";
00616                         }
00617                         res+="[c]\t[p]";
00618                         res+="\n";
00619                 }
00620                 for(int r=0; r < rows; r++) {
00621                         for(int c=0; c<cols; c++) {
00622                                 res+=get(r,c).to_string();
00623                                 if(c<cols-1) res+="\t";
00624                         }
00625                         if(big_rows[r]) res+=" (b)";
00626                         res+="\n";
00627                 }
00628                 return res;
00629         }
        // Stream-insertion operator for matrices. Only declared (as a friend)
        // here; the definition is not part of this section of the file.
        friend ostream& operator <<(ostream &os,const matrix &obj);
00631 
00632         //----------------------------------------------------------
00633         /*
00634          * Simplex specific functions for the matrix are below here.
00635          */
00636         //----------------------------------------------------------
00637 
00638         inline bignum get_coef_gcd(int r)
00639         {
00640                 if(big_rows[r])
00641                 {
00642                         mpz_t gcd;
00643                         //mpz_init_set(gcd, dmatrix[r*mcols].i);
00644                         mpz_init_set(gcd, MI(r,0));
00645                         mpz_abs(gcd, gcd);
00646                         //for(int i= r*mcols+1; i < r*mcols+cols-2; i++){
00647                         //      mpz_gcd(gcd, gcd, dmatrix[i].i);
00648                         //}
00649                         for(int i= 1; i < cols-2; i++){
00650                                 mpz_gcd(gcd, gcd, MI(r,i));
00651                         }
00652                         bignum b(gcd);
00653                         mpz_clear(gcd);
00654                         return b;
00655                 }
00656                 //long int gcd = labs(dmatrix[r*mcols].d);
00657                 long int gcd = labs(MD(r,0));
00658                 //for(int i= r*mcols+1; i < r*mcols+cols-2; i++){
00659                 //      gcd = bignum::compute_int_gcd(gcd, dmatrix[i].d);
00660                 //}
00661                 for(int i= 1; i < cols-2; i++){
00662                         gcd = bignum::compute_int_gcd(gcd, MD(r,i));
00663                 }
00664                 return bignum(gcd);
00665 
00666         }
00667 
00668         inline vector<string> & get_vars()
00669         {
00670                 return *vars;
00671         }
00672 
00673 
00674         /*
00675          * Returns the pivot element of row r
00676          */
00677         inline int get_pivot(int r)
00678         {
00679                 if(CHECK_BOUNDS){
00680                         assert(r>=0 && r<rows);
00681                 }
00682                 if(!big_rows[r]){
00683                         //return dmatrix[r*mcols+(cols-1)].d;
00684                         return MD(r, cols-1);
00685                 }
00686                 return get(r, cols-1).to_int();
00687         }
00688 
00689         inline bignum get_constant(int r)
00690         {
00691                 if(CHECK_BOUNDS){
00692                         assert(r>=0 && r<rows);
00693                 }
00694                 if(big_rows[r]){
00695                         //return bignum(dmatrix[r*mcols+(cols-2)].i);
00696                         return bignum(MI(r, cols-2));
00697                 }
00698                 //return bignum(dmatrix[r*mcols+(cols-2)].d);
00699                 return bignum(MD(r, cols-2));
00700         }
00701 
00702         inline void set_constant(int r, bignum  b)
00703         {
00704                 set(r, cols-2, b);
00705         }
00706 
00707         inline void set_pivot(int r, int p)
00708         {
00709                 set(r, cols-1, p);
00710         }
00711 
00712         /*
00713          * Returns the index of the first positive
00714          * coefficient in row r.
00715          */
00716         inline int get_first_positive_index(int r)
00717         {
00718                 if(CHECK_BOUNDS){
00719                         assert(r>=0 && r<rows);
00720                 }
00721                 if(!big_rows[r]) {
00722                         for(int c=0; c < cols-2; c++) {
00723                                 //if(dmatrix[row+c].d>0)
00724                                 //      return c;
00725                                 if(MD(r,c)>0)
00726                                         return c;
00727                         }
00728                         return -1;
00729                 }
00730 
00731                 for(int c=0; c < cols-2; c++) {
00732                         //if(mpz_cmp_si(dmatrix[row+c].i, 0)>0)
00733                         //      return c;
00734                         if(mpz_cmp_si(MI(r,c), 0)>0)
00735                                 return c;
00736                 }
00737                 return -1;
00738         }
00739 
00740 
00741 
        /*
         * Pivots the matrix on entry (pivot_row, pivot_index).
         *
         * The pivot row is sign-flipped first if the pivot entry is negative,
         * and pivot_index is recorded in the last column of the pivot row.
         * Every other row r is then scaled by f_cur and combined with the
         * pivot row using factor f_pivot, where the two factors are the
         * pivot-column entries of the two rows divided by their gcd.
         * NOTE(review): sub_multiply_row / sub_multiply_row_d_i /
         * pivot_row_d_i are defined outside this view; presumably they
         * subtract f_pivot times the pivot row so that column pivot_index
         * becomes 0 in row r -- confirm.
         *
         * simplify: when true, simplify_row() is called on each native row
         * after it has been updated. Rows that are (or become) bignum rows
         * are always passed to simplify_row_i(), regardless of this flag.
         */
        void pivot(int pivot_row, int pivot_index, bool simplify = true)
        {
                if(CHECK_BOUNDS){
                        // Note: the last row and the two trailing columns are
                        // not accepted as pivot positions.
                        assert(pivot_row>=0 && pivot_row<rows-1);
                        assert(pivot_index>=0 && pivot_index< cols-2);
                }

                // Make the pivot entry non-negative.
                bignum pc = get(pivot_row, pivot_index);
                if(pc<0)
                        flip_row_sign(pivot_row);

                /*
                 * First, if pivot row is bignum, we turn everything big
                 * and hand the whole pivot over to pivot_i().
                 */
                if(big_rows[pivot_row]){
                        for(int r=0; r<rows; r++) {
                                if(big_rows[r]) continue;
                                infinitize_row(r);
                        }
                        pivot_i(pivot_row, pivot_index, simplify);
                        return;
                }
#ifdef ENABLE_CHECKS
                /*
                 * Consistency checking only works if everything is a bignum,
                 * so reaching the native path in checking mode is an error.
                 */
                assert(false);
#endif
                //record the pivot index
                set(pivot_row, cols-1, pivot_index);

                // Largest magnitude in columns 0..cols-2 of the pivot row;
                // passed to sub_multiply_row() below.
                //long int pivot_max = labs(dmatrix[pivot_row*mcols].d);
                long int pivot_max = labs(MD(pivot_row, 0));
                //for(int i= pivot_row*mcols+1; i < pivot_row*mcols+cols-1; i++){
                //      if(labs(dmatrix[i].d) > pivot_max){
                //              pivot_max = labs(dmatrix[i].d);
                //      }
                //}
                for(int i= 1; i < cols-1; i++){
                        if(labs(MD(pivot_row, i)) > pivot_max){
                                pivot_max = labs(MD(pivot_row, i));
                        }
                }

                //long int pivot_c = dmatrix[pivot_row*mcols+pivot_index].d;
                long int pivot_c = MD(pivot_row,pivot_index);
                for(int r=0; r < rows; r++) {
                        if(r==pivot_row) continue;

                        // Row r is already a bignum row while the pivot row is
                        // native: use the mixed helper.
                        if(big_rows[r]){
                                pivot_row_d_i(pivot_row, pivot_index, pivot_c, r);
                                simplify_row_i(r);
                                continue;
                        }

                        //long int cur_c = dmatrix[r*mcols+pivot_index].d;
                        long int cur_c = MD(r,pivot_index);
                        long int gcd = bignum::compute_int_gcd(pivot_c, cur_c);
                        if(gcd == 0) continue;
                        // Reduced factors: cur_c*f_cur == pivot_c*f_pivot.
                        long int f_pivot =  cur_c / gcd;
                        long int f_cur =  pivot_c / gcd;

                        // Keep the factor applied to row r non-negative.
                        if(f_cur<0) {
                                f_cur = -f_cur;
                                f_pivot = -f_pivot;
                        }

                        /*
                         * If we don't have digits left for the multiplication, upgrade to
                         * bignums.
                         */
                        if(bignum::m_overflow(f_cur, cur_c) ||
                                        bignum::m_overflow(f_pivot, pivot_c ))
                        {

                                infinitize_row(r);
                                multiply_row(r, f_cur);
                                sub_multiply_row_d_i(r,  pivot_row, f_pivot);
                                simplify_row_i(r);
                                continue;
                        }
                        multiply_row(r, f_cur);
                        // multiply_row() may itself have upgraded row r.
                        if(big_rows[r]){
                                sub_multiply_row_d_i(r,  pivot_row, f_pivot);
                                simplify_row_i(r);
                                continue;
                        }
                        sub_multiply_row(pivot_max, r, pivot_row, f_pivot);
                        if(simplify) simplify_row(r);
                }

#ifdef ENABLE_CHECKS
                        // Post-condition: for every row but the last, the entry
                        // at its recorded pivot column is positive and that
                        // column is 0 in all other rows.
                        for(int r=0; r < rows-1; r++) {
                                int c_pivot = get_pivot(r);
                                assert(get(r, c_pivot)>0);
                                for(int r2=0; r2<rows; r2++) {
                                        if(r2 == r) continue;
                                        assert(get(r2, c_pivot) == 0);
                                }

                        }
#endif
        }
00845 
00846 
00847 
00848 
00849 //-------------------------------------------------------------
00850 void multiply(matrix & op, matrix & result)
00851 {
00852         assert(cols == op.rows);
00853         assert(result.rows == rows);
00854         assert(result.cols == op.cols);
00855 
00856         for(int r=0; r < result.rows; r++)
00857         {
00858                 for(int c=0; c < result.cols; c++)
00859                 {
00860                         result.set(r,c,dot_product(r,c, op));
00861                 }
00862         }
00863 
00864 }
00865 
00866 inline bignum dot_product(int r, int cc, matrix &op)
00867 {
00868         bignum res;
00869         for(int c=0; c < cols; c++) {
00870                 bignum my_e = get(r, c);
00871                 bignum other_e = op.get(c, cc);
00872                 res+=my_e*other_e;
00873         }
00874         return res;
00875 }
00876 
00877 
00878 
00879 
00880 void compute_redundant_rows(std::set<int> &red_rows)
00881 {
00882         matrix m(rows+1, cols+2);
00883         for(int r=0; r < rows; r++)
00884         {
00885                 for(int c=0; c <cols; c++)
00886                 {
00887                         m.set(r,c,get(r,c));
00888                 }
00889         }
00890 
00891         map<int, int> row_map;
00892         for(int r=0; r <rows; r++) {
00893                 int pivot_c = -1;
00894                 for(int c=0; c < cols; c++) {
00895                         if(row_map.count(c) >0) continue;
00896                         if(m.get(r, c) == 0) continue;
00897                         pivot_c = c;
00898                         row_map[pivot_c] = r;
00899                         break;
00900                 }
00901                 if(pivot_c == -1){
00902                         red_rows.insert(r);
00903                         continue;
00904                 }
00905                 m.pivot(r, pivot_c);
00906         }
00907 
00908 }
00909 
00910 bignum invert(matrix & result)
00911 {
00912 
00913         assert(rows == cols);
00914         assert(result.cols == rows);
00915         assert(result.cols == result.rows);
00916 
00917         matrix m(rows+1, 2*cols+2);
00918         for(int r=0; r < rows; r++){
00919                 for(int c=0; c < cols; c++) {
00920                         m.set(r, c, get(r,c));
00921                 }
00922                 m.set(r, cols+r, 1);
00923         }
00924 
00925         map<int, int> row_map;
00926 
00927         for(int r=0; r <rows; r++) {
00928                 int pivot_c = -1;
00929                 for(int c=0; c < cols; c++) {
00930                         if(row_map.count(c) >0) continue;
00931                         if(m.get(r, c) == 0) continue;
00932                         pivot_c = c;
00933                         row_map[pivot_c] = r;
00934                         break;
00935                 }
00936                 //system was linearely dependent
00937                 assert(pivot_c != -1);
00938                 m.pivot(r, pivot_c);
00939         }
00940 
00941 
00942         bignum gcd = 0;
00943         bignum lcm = 1;
00944 
00945         for(int pivot_c = 0; pivot_c < cols; pivot_c++) {
00946                 assert(row_map.count(pivot_c) > 0);
00947                 int row = row_map[pivot_c];
00948                 bignum p_coef = m.get(row, pivot_c);
00949                 gcd = p_coef.compute_gcd(lcm);
00950                 bignum n_coef = p_coef.divexact(gcd);
00951                 if(n_coef < 0) n_coef = -n_coef;
00952                 lcm*= n_coef;
00953         }
00954 
00955         for(int pivot_c = 0; pivot_c < cols; pivot_c++) {
00956                 int row = row_map[pivot_c];
00957                 bignum p_coef = m.get(row, pivot_c);
00958                 assert(lcm.divisible(p_coef));
00959                 bignum factor = lcm.divexact(p_coef);
00960                 m.multiply_row(row, factor);
00961         }
00962 
00963 
00964         for(int orig_c = 0; orig_c<cols; orig_c++) {
00965                 int row = row_map[orig_c];
00966                 for(int c =0; c < cols; c++){
00967 
00968                         result.set(orig_c,c, m.get(row,c+cols));
00969                 }
00970         }
00971 
00972 
00973         return lcm;
00974 
00975 }
00976 
00977 void vector_multiply(bignum * b, bignum* res)
00978 {
00979         for(int r = 0; r < rows; r++) {
00980                 bignum cur_res = 0;
00981                 for(int c=0; c < cols; c++) {
00982                         bignum a_rc = get(r, c);
00983                         cur_res += a_rc*b[c];
00984                 }
00985                 res[r] = cur_res;
00986         }
00987 }
00988 
00989 void row_vector_multiply(bignum * b, bignum* res)
00990 {
00991         for(int c = 0; c < cols; c++) {
00992                 bignum cur_res = 0;
00993                 for(int r=0; r < rows; r++) {
00994                         bignum a_rc = get(r, c);
00995                         cur_res += a_rc*b[r];
00996                 }
00997                 res[c] = cur_res;
00998         }
00999 }
01000 
01001 //-------------------------------------------------------------
01002 
01003 
01004 private:
01005 
01006 
01007         inline void add_row_i(int r, bignum & f)
01008         {
01009                 if(f.infinite) {
01010                         //for(int i = r*mcols; i<end; i++) {
01011                         //      mpz_add(dmatrix[i].i, dmatrix[i].i, f.data.i);
01012                         //}
01013                         for(int i = 0; i<cols-1; i++) {
01014                                 mpz_add(MI(r, i), MI(r, i), f.data.i);
01015                         }
01016                         return;
01017                 }
01018                 mpz_t temp;
01019                 mpz_init_set_si(temp, f.data.d);
01020                 //for(int i = r*mcols; i<end; i++) {
01021                 //      mpz_add(dmatrix[i].i, dmatrix[i].i, temp);
01022                 //}
01023 
01024                 for(int i = 0; i<cols-1; i++) {
01025                         mpz_add(MI(r, i), MI(r, i), temp);
01026                 }
01027                 mpz_clear(temp);
01028         }
01029 
01030         inline void multiply_row_i(int r, bignum & f)
01031         {
01032                 if(f.infinite) {
01033                         /*for(int i = r*mcols; i<end; i++) {
01034                                 mpz_mul(dmatrix[i].i, dmatrix[i].i, f.data.i);
01035                         }*/
01036                         for(int i = 0; i<cols-1; i++) {
01037                                 mpz_mul(MI(r, i), MI(r, i), f.data.i);
01038                         }
01039                         return;
01040                 }
01041                 //for(int i = r*mcols; i<end; i++) {
01042                 //      mpz_mul_si(dmatrix[i].i, dmatrix[i].i, f.data.d);
01043                 //}
01044                 for(int i = 0; i<cols-1; i++) {
01045                         mpz_mul_si(MI(r, i), MI(r, i), f.data.d);
01046                 }
01047         }
01048 
01049         inline void divide_row_i(int r, bignum & f)
01050         {
01051                 if(f.infinite) {
01052                         for(int i = 0; i<cols-1; i++) {
01053                                 mpz_divexact(MI(r, i), MI(r,i), f.data.i);
01054                         }
01055                         return;
01056                 }
01057                 for(int i = 0; i<cols-1; i++) {
01058 
01059                         mpz_divexact_ui(MI(r,i), MI(r,i), f.data.d);
01060                 }
01061         }
01062 
01063         inline void simplify_row(int r)
01064         {
01065                 if(CHECK_BOUNDS){
01066                         assert(r>=0 && r< rows);
01067                 }
01068                 if(big_rows[r]){
01069                         simplify_row_i(r);
01070                         return;
01071                 }
01072                 long int gcd = labs(MD(r,0));
01073                 for(int i= 1; i < cols-1; i++) {
01074                         gcd = bignum::compute_int_gcd(gcd, MD(r,i));
01075                 }
01076                 if(gcd == 0 || gcd == 1) return;
01077                 for(int i = 0; i <cols-1; i++) {
01078                         MD(r,i)/=gcd;
01079                 }
01080         }
01081 
01082 
01083 
01084         inline void simplify_row_i(int r)
01085         {
01086                 mpz_t gcd;
01087                 mpz_init_set(gcd, MI(r,0));
01088                 mpz_abs(gcd, gcd);
01089                 for(int i=1; i < cols-1; i++) {
01090                         mpz_gcd(gcd, gcd, MI(r,i));
01091                 }
01092                 if(mpz_cmp_si(gcd, 0) ==0 || mpz_cmp_si(gcd, 1) ==0){
01093                         mpz_clear(gcd);
01094                         return;
01095                 }
01096                 for(int i=0; i < cols-1; i++) {
01097                         mpz_divexact(MI(r,i), MI(r,i), gcd);
01098                 }
01099                 mpz_clear(gcd);
01100         }
01101 
01102         inline void pivot_row_d_i(int pivot_row, int pivot_index,
01103                         long int piv_c, int r)
01104         {
01105                 if(CHECK_BOUNDS){
01106                         assert(big_rows[r]);
01107                         assert(!big_rows[pivot_row]);
01108                 }
01109 
01110 
01111                 mpz_t pivot_c;
01112                 mpz_init_set_si(pivot_c, piv_c);
01113 
01114 
01115                 mpz_t cur_c;
01116                 mpz_t gcd;
01117                 mpz_t f_pivot;
01118                 mpz_t f_cur;
01119 
01120                 mpz_init(cur_c);
01121                 mpz_init(gcd);
01122                 mpz_init(f_pivot);
01123                 mpz_init(f_cur);
01124 
01125                 mpz_set(cur_c, MI(r,pivot_index));
01126                 mpz_gcd(gcd, pivot_c, cur_c);
01127                 if(mpz_cmp_si(gcd, 0) == 0){
01128                         mpz_clear(pivot_c);
01129                         mpz_clear(cur_c);
01130                         mpz_clear(gcd);
01131                         mpz_clear(f_pivot);
01132                         mpz_clear(f_cur);
01133                         return;
01134                 }
01135 
01136                 mpz_divexact(f_pivot, cur_c, gcd);
01137                 mpz_divexact(f_cur, pivot_c, gcd);
01138                 if(mpz_cmp_si(f_cur, 0)<0)
01139                 {
01140                         mpz_neg(f_pivot, f_pivot);
01141                         mpz_neg(f_cur, f_cur);
01142                 }
01143                 mpz_t cur;
01144                 mpz_init(cur);
01145                 for(int c=0; c<cols-1; c++) {
01146                         mpz_mul(MI(r, c), MI(r, c), f_cur);
01147                         mpz_set_si(cur, MD(pivot_row, c));
01148                         mpz_submul(MI(r, c), cur, f_pivot);
01149                 }
01150                 mpz_clear(cur);
01151 
01152                 mpz_clear(pivot_c);
01153                 mpz_clear(cur_c);
01154                 mpz_clear(gcd);
01155                 mpz_clear(f_pivot);
01156                 mpz_clear(f_cur);
01157         }
01158 
01159 
01160         /*
01161          * Dest is bignum, pivot row is not. Dest has already been mutiplied.
01162          */
01163         inline void sub_multiply_row_d_i(int dest_r, int pivot_r,  long int f_pivot)
01164         {
01165                 mpz_t f_piv;
01166                 mpz_t cur;
01167                 mpz_init(cur);
01168                 mpz_init_set_si(f_piv, f_pivot);
01169                 for(int c=0; c<cols-1; c++) {
01170                         mpz_set_si(cur, MD(pivot_r, c));
01171                         mpz_submul(MI(dest_r, c), cur, f_piv);
01172                 }
01173                 mpz_clear(f_piv);
01174                 mpz_clear(cur);
01175         }
01176 
01177 
01178         /*
01179          * This function assumes that both the pivot row and the dest row
01180          * are currently NOT bignums. Max is maximum abs value in pivot row.
01181          */
01182         inline void sub_multiply_row(long int max, int dest_r, int pivot_r,
01183                         long int f_pivot)
01184         {
01185                 if(bignum::m_overflow(max, f_pivot))
01186                 {
01187                         infinitize_row(dest_r);
01188                         sub_multiply_row_d_i(dest_r, pivot_r, f_pivot);
01189                         return;
01190                 }
01191                 long int max_dest = MD(dest_r, 0);
01192                 for(int i=1; i < cols-1; i++) {
01193                         if(labs(MD(dest_r, i)) > max_dest) {
01194                                 max_dest = labs(MD(dest_r, i));
01195                         }
01196                 }
01197                 long int max_pf = max*f_pivot;
01198                 if(bignum::a_overflow(max_pf, max_dest))
01199                 {
01200                         infinitize_row(dest_r);
01201                         sub_multiply_row_d_i(dest_r, pivot_r, f_pivot);
01202                         return;
01203                 }
01204 
01205 #ifdef VECTORIZE_TWO
01206                 {
01207                         d_type last = MD(dest_r, cols-1);
01208                         v2si c = {f_pivot, f_pivot};
01209                         v2si* l = (v2si*)&MD(dest_r, 0);
01210                         v2si* l_end = (v2si*)&MD(dest_r, cols-1);
01211                         v2si* r = (v2si*)&MD(pivot_r, 0);
01212                         for(; l<l_end; l++, r++) {
01213                                 (*l) -= (*r)*c;
01214                         }
01215                         MD(dest_r, cols-1) = last;
01216                 }
01217 #endif
01218 
01219 #ifdef VECTORIZE_FOUR
01220                 {
01221                         d_type last = MD(dest_r, cols-1);
01222                         v4si c = {f_pivot, f_pivot, f_pivot, f_pivot};
01223                         v4si* l = (v4si*)&MD(dest_r, 0);
01224                         v4si* l_end = (v4si*)&MD(dest_r, cols-1);
01225                         v4si* r = (v4si*)&MD(pivot_r, 0);
01226                         for(; l<l_end; l++, r++) {
01227                                 (*l) -= (*r)*c;
01228                                 //MD(r, i)*=f.data.d;
01229                         }
01230                         MD(dest_r, cols-1) = last;
01231                 }
01232 #endif
01233 
01234 #ifndef VECTORIZE_TWO
01235 #ifndef VECTORIZE_FOUR
01236 
01237                 for(int c=0; c < cols-1; c++) {
01238                         MD(dest_r, c) -= MD(pivot_r, c) * f_pivot;
01239                 }
01240 #endif
01241 #endif
01242 
01243 
01244         }
01245 
01246         inline void pivot_i(int pivot_row, int pivot_index, bool simplify)
01247         {
01248                 mpz_set_si(MI(pivot_row, cols-1), pivot_index);
01249 
01250                 //now, all of the matix is a bignum.
01251                 mpz_t pivot_c;
01252                 mpz_init_set(pivot_c, MI(pivot_row, pivot_index));
01253 
01254                 mpz_t cur_c;
01255                 mpz_t gcd;
01256                 mpz_t f_pivot;
01257                 mpz_t f_cur;
01258 
01259                 mpz_init(cur_c);
01260                 mpz_init(gcd);
01261                 mpz_init(f_pivot);
01262                 mpz_init(f_cur);
01263 
01264 #ifdef ENABLE_CHECKS
01265                 cmatrix[pivot_row*mcols+cols-1] = pivot_index;
01266 #endif
01267 
01268                 for(int r=0; r < rows; r++){
01269                         if(r==pivot_row) continue;
01270 
01271 #ifdef ENABLE_CHECKS
01272                         bignum b_pivot_c = cmatrix[pivot_row*mcols+pivot_index];
01273                         bignum b_cur_c = cmatrix[r*mcols+pivot_index];
01274                         bignum b_gcd = b_cur_c.compute_gcd(b_pivot_c);
01275                         bignum b_f_pivot = b_cur_c/b_gcd;
01276                         bignum b_f_cur = b_pivot_c/b_gcd;
01277                         assert(b_f_cur >= 0);
01278                         for(int c=0; c<cols-1; c++) {
01279                                 cmatrix[r*mcols+c]*=b_f_cur;
01280                                 cmatrix[r*mcols+c]-=cmatrix[pivot_row*mcols+c]*b_f_pivot;
01281                         }
01282 #endif
01283 
01284 
01285                         mpz_set(cur_c, MI(r, pivot_index));
01286                         if(mpz_cmp_si(cur_c, 0) == 0) continue;
01287                         mpz_gcd(gcd, pivot_c, cur_c);
01288                         if(mpz_cmp_si(gcd, 0) == 0) continue;
01289                         mpz_divexact(f_pivot, cur_c, gcd);
01290                         mpz_divexact(f_cur, pivot_c, gcd);
01291                         if(mpz_cmp_si(f_cur, 0)<0)
01292                         {
01293                                 mpz_neg(f_pivot, f_pivot);
01294                                 mpz_neg(f_cur, f_cur);
01295                         }
01296 
01297                         for(int c=0; c<cols-1; c++) {
01298                                 mpz_mul(MI(r, c), MI(r, c), f_cur);
01299                                 mpz_submul(MI(r, c),
01300                                                 MI(pivot_row, c), f_pivot);
01301                         }
01302                         if(simplify) simplify_row_i(r);
01303 
01304                 }
01305                 mpz_clear(pivot_c);
01306                 mpz_clear(cur_c);
01307                 mpz_clear(gcd);
01308                 mpz_clear(f_pivot);
01309                 mpz_clear(f_cur);
01310 
01311 #ifdef ENABLE_CHECKS
01312                 check_consistency();
01313 #endif
01314         }
01315 
01316 
01317         inline void infinitize_row(int r)
01318         {
01319                 for(int i = cols-1; i >= 0; i--)
01320                 {
01321                         long int val = MD(r, i);
01322                         mpz_init_set_si(MI(r, i), val);
01323                         bignum t(MI(r,i));
01324                 }
01325                 big_rows[r] = true;
01326         }
01327 
/*
 * Debug aid: when ENABLE_CHECKS is defined, compares every entry of the
 * matrix (all rows, all cols columns) with the bignum shadow matrix
 * cmatrix, reporting the first mismatch on cout and aborting via assert.
 * Without ENABLE_CHECKS this function does nothing.
 *
 * Entries are read as GMP integers, so presumably every row must already
 * be stored as GMP integers when this is called -- confirm with callers.
 */
void check_consistency()
{
#ifdef ENABLE_CHECKS
        cout << "checking consistency..." << endl;

        for(int r = 0; r < rows; ++r) {
                const int base = r*mcols;
                for(int c = 0; c < cols; ++c) {
                        const int index = base + c;
                        bignum actual = bignum(dmatrix[index].i);
                        if(actual != cmatrix[index]) {
                                cout << "Error: " << r << " " << c << endl;
                                assert(false);
                        }
                }
        }
#endif
}
01349 
01350 
01351 };
01352 
01353 #endif /* MATRIX_H_ */