/*************************************************************

    This program is a test driver demonstrating the SHA C++ classes I adapted
    from a C version written by Aaron D. Gifford
    (as of 11/22/2004 his code could be found at http://www.adg.us/computers/sha.html).
    Attempts to contact him were unsuccessful.  I make use of his input files
    for testing.  The format should be flexible enough to add any number
    of test files, the only restriction is that they be in a subdirectory
    called 'testvectors' and all the files begin with 'vector'.  The program can
    be called with a 'verbose' mode to provide details of each test (not all tests
    are available for all bit-lengths of each hash).  With verbose off only
    a pass/fail message is written.  For your amusement you can also run a speed
    test of the algorithm.  It generates the hash for a user-defined number
    of iterations (default is 1 iteration) on a 16 megabyte (2^24-byte) string of 'a's.
    To run the test vectors pass 'validate' on the command line; to run the speed
    test pass 'speedtest', optionally together with an integer iteration count.
    At least one of 'validate' or 'speedtest' is required.

    If you use this implementation somewhere I would like to be credited
    with my work (a link to my page below is fine).  I add no license
    restriction beyond any that is made by the original author.  This
    code comes with no warranty, expressed or implied; use at your own
    risk!

    Keith Oxenrider
    koxenrider[at]sol[dash]biotech[dot]com
    The latest version of this code should be available via the page
    sol-biotech.com/code.

*************************************************************/

#include <algorithm>
#include <cctype>
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#ifdef WIN32
    #include <io.h>
    #pragma warning(disable: 4786) //useful with Visual Studio C++ version 6
#else //presumed *nix
    #include <sys/stat.h>
    #include <dirent.h>
#endif

#include "sha2.h"

using namespace std;


#ifdef WIN32
    // Append to vectData the path ("sdir\name") of every non-directory entry
    // of sdir whose name begins with "vector".  Only the entries added by this
    // call are sorted, so test vectors are processed in a stable order.  If
    // the directory cannot be opened, an error goes to stderr and vectData is
    // left unchanged.
    void listDir(const char * sdir, vector < string > &vectData){
        const vector < string >::size_type firstNew = vectData.size();
        const string pattern = sdir + (string) "\\*";
        struct _finddata_t c_file;

        // _findfirst returns an intptr_t handle; a long would truncate it on 64-bit Windows.
        intptr_t hFile = _findfirst(pattern.c_str(), &c_file);
        if (hFile == -1){
            cerr << "Error opening " << pattern << "!\n";
            return;
        }

        do{
            // "." and ".." are directories, so this also skips them.
            if (c_file.attrib & _A_SUBDIR) continue;
            if (strncmp(c_file.name, "vector", 6) != 0) continue;
            vectData.push_back(sdir + (string) "\\" + c_file.name);
        }while(_findnext(hFile, &c_file) == 0);
        _findclose(hFile);

        sort(vectData.begin() + static_cast< vector < string >::difference_type >(firstNew),
             vectData.end());
    }
#else //presume *nix
    // Append to vectData the path ("sdir/name") of every non-directory entry
    // of sdir whose name begins with "vector".  Only the entries added by this
    // call are sorted, because readdir() order is unspecified.  If the
    // directory cannot be opened, an error goes to stderr and vectData is left
    // unchanged.
    void listDir(const char * sdir, vector < string > &vectData){
        const vector < string >::size_type firstNew = vectData.size();

        DIR *dirp = opendir(sdir);
        if (dirp == NULL){
            // perror appends ": <reason>" itself.
            perror(("Error opening " + string(sdir)).c_str());
            return;
        }

        struct dirent *dptr;
        while ((dptr = readdir(dirp)) != NULL){
            // "." and ".." fail this prefix test too, so no separate check is needed.
            if (strncmp(dptr->d_name, "vector", 6) != 0) continue;
            const string path = sdir + (string) "/" + dptr->d_name;
            struct stat st;
            // Skip sub-directories; if stat fails the entry is kept, as before.
            if (stat(path.c_str(), &st) == 0 && S_ISDIR(st.st_mode)) continue;
            vectData.push_back(path);
        }
        closedir(dirp);

        sort(vectData.begin() + static_cast< vector < string >::difference_type >(firstNew),
             vectData.end());
    }
#endif

// Section tags that can appear in a test-vector file.  Each value is also an
// index into the strData[] array used by test().  The hash sections reuse
// sha2's own enumerator values, so test() can loop over
// sha2::enuSHA1 .. sha2::enuSHA_LAST and cast the index back to sha2::SHA_TYPE.
// The remaining tags follow implicitly after enuSHA512.
// NOTE(review): this assumes sha2's hash enumerators are small, contiguous and
// that enuSHA512 is the largest of them, so enuNOTYPE.. never collide with a
// hash index and enuLAST >= sha2::enuSHA_LAST -- confirm against sha2.h.
enum SHA_FILE_TYPES{
    enuSHA1   = sha2::enuSHA1,
    enuSHA224 = sha2::enuSHA224,
    enuSHA256 = sha2::enuSHA256,
    enuSHA384 = sha2::enuSHA384,
    enuSHA512 = sha2::enuSHA512,
    enuNOTYPE,  // no section header seen yet, or an unrecognized one
    enuDESCR,   // "DESCRIPTION": free text, ignored by test()
    enuASCII,   // "ASCII": message given as text lines, concatenated
    enuHEX,     // "HEX": message given as hex digit pairs
    enuREPEAT,  // "REPEAT": count of times to repeat the ASCII message
    enuRANGE,   // "RANGE": test() substitutes the bytes 0..127 as the message
    enuLAST     // number of tags; used to size strData[]
};

void test(bool verbose, bool validate, bool speedtest, unsigned int ITERS){
    string strBuf, strTmp;
    fstream fin;
    string strData[enuLAST];
    SHA_FILE_TYPES type = enuNOTYPE, tmpType;
    unsigned int i, j, cnt, hextmp;
    bool passed, tmppassed;
    unsigned char hexstr[3];
    sha2 mySha2;
    vector < string > vectData;

    hexstr[2] = '\0';

    if (mySha2.IsBigEndian()) cout << "This machine is Big Endian.\n";
    else cout << "This machine is Little Endian.\n";

    if (validate){
        listDir("testvectors", vectData);
        for (i=0; i<vectData.size(); i++){

            passed = true;
            for (j=0; j<enuLAST; j++) strData[j] = "";
            cout << "processing " << vectData[i] << "... ";
            fin.open(vectData[i].c_str(), fstream::in);
            if (!fin){
                cerr << "\tCan't open " << vectData[i].c_str() << "!\n";
            }else{
                while (true){
                    getline(fin, strBuf);
                    if (fin.eof()) break;
                    if (strBuf == "") continue;
                    if (strBuf[strBuf.size()-1] == '\r') strBuf.erase(strBuf.size()-1, 1);
                    if (strBuf[0] == '\r') strBuf.erase(0, 1);
                    if (strBuf == "") continue;
                    if (strncmp(strBuf.c_str(), "    ", 4)){
                        if (!strncmp(strBuf.c_str(), "DESCRIPTION", 11)){
                            type = enuDESCR;
                        }else if (!strncmp(strBuf.c_str(), "SHA1", 4)){
                            type = enuSHA1;
                        }else if (!strncmp(strBuf.c_str(), "SHA224", 6)){
                            type = enuSHA224;
                        }else if (!strncmp(strBuf.c_str(), "SHA256", 6)){
                            type = enuSHA256;
                        }else if (!strncmp(strBuf.c_str(), "SHA384", 6)){
                            type = enuSHA384;
                        }else if (!strncmp(strBuf.c_str(), "SHA512", 6)){
                            type = enuSHA512;
                        }else if (!strncmp(strBuf.c_str(), "ASCII", 5)){
                            type = enuASCII;
                        }else if (!strncmp(strBuf.c_str(), "HEX", 3)){
                            type = enuHEX;
                        }else if (!strncmp(strBuf.c_str(), "REPEAT", 6)){
                            type = enuREPEAT;
                        }else if (!strncmp(strBuf.c_str(), "RANGE", 5)){
                            type = enuRANGE;
                        }else{
                            type = enuNOTYPE;
                            cerr << "Unknown type: [" << strBuf << "]\n";
                        }
                    }else{
                        stringstream iostr;
                        iostr << strBuf;
                        iostr >> strTmp;
                        switch(type){
                            case enuDESCR:  break;
                            case enuSHA1:   strData[enuSHA1] += strTmp; break;
                            case enuSHA224: strData[enuSHA224] += strTmp; break;
                            case enuSHA256: strData[enuSHA256] += strTmp; break;
                            case enuSHA384: strData[enuSHA384] += strTmp; break;
                            case enuSHA512: strData[enuSHA512] += strTmp; break;
                            case enuASCII: {
                                //erase leading spaces
                                while (isspace(strBuf[0])) strBuf.erase(0, 1);
                                //erase trailing spaces
                                while (isspace(strBuf[0])) strBuf.erase(strBuf.size()-1, 1);
                                //erase leading and trailing quotes
                                if (strBuf[0] == '"') strBuf.erase(0, 1);
                                if (strBuf[strBuf.size()-1] == '"') strBuf.erase(strBuf.size()-1, 1);
                                strData[enuASCII] += strBuf;
                                break;
                            }
                            case enuHEX: {
                                for (j=2; j<strTmp.size(); j+=2){
                                    hexstr[0] = strTmp[j];
                                    hexstr[1] = strTmp[j+1];
                                    sscanf((const char *)hexstr, "%x", &hextmp);
                                    strData[enuHEX] += (char)hextmp;
                                }
                                break;
                            }
                            case enuREPEAT: strData[enuREPEAT] += strTmp; break;
                            case enuRANGE: strData[enuRANGE] += strTmp; break;
                            default:        cerr << "Invalid type!\n";
                        }
                    }
                    if (!fin.good()) break;
                }
            }
            fin.close();
            fin.clear();

            if (strData[enuRANGE].size()){//special, 1 type case
                strTmp = "";
                for (j=0; j<128; j++)
                    strTmp += j;
                tmpType = enuASCII;
                strData[enuASCII] = strTmp;
            }else if (strData[enuREPEAT].size()){
                tmpType = enuREPEAT;
                cnt = atoi(strData[enuREPEAT].c_str());
                strTmp.reserve(cnt * strData[enuASCII].size());
                strTmp = "";
                for (j=0; j<cnt; j++)
                    strTmp += strData[enuASCII];
                strData[enuREPEAT] = strTmp;
            }
            else if (strData[enuHEX].size()) tmpType = enuHEX;
            else if (strData[enuASCII].size()) tmpType = enuASCII;

            for (j=sha2::enuSHA1; j<sha2::enuSHA_LAST; j++){
                if (strData[j].size()) {
                    const string &strHash = mySha2.GetHash((sha2::SHA_TYPE) j, 
                                    (const unsigned char *)strData[tmpType].c_str(), 
                                    strData[tmpType].size());
                    tmppassed = (strData[j] == strHash);
                    if (!tmppassed) passed = tmppassed;
                    if (tmppassed && verbose) cout << "\n\t" << mySha2.GetTypeString() << "   passed";
                    else if (!tmppassed) cout << "\n\t" << mySha2.GetTypeString() << "   FAILED";
                }
            }

            if (passed && !verbose) cout << "passed.\n";
            else if (!passed && !verbose)  cout << "FAILED!\n";
            else cout << "\n";
        }
    }//end if validate


    if (speedtest){
        //do some performance testing...
        clock_t t_start, t_stop;
        double runTime;
        string strDat, strType;

        strDat.reserve(16777216);//2^24
        strDat = "a";
        for (i=0; i<24; i++) strDat += strDat;

        for (j=sha2::enuSHA1; j<sha2::enuSHA_LAST; j++){
            t_start = clock();
            for (i=0; i<ITERS; i++){
                mySha2.GetHash((sha2::SHA_TYPE) j, (const unsigned char *)strDat.c_str(), 
                            strDat.size());
            }
            t_stop = clock();
            runTime = (t_stop - t_start) / (double)CLOCKS_PER_SEC;
            cout << "It took " << runTime << " sec to run ";
            cout << mySha2.GetTypeString();
            if (ITERS == 1)
                cout << " once on ";
            else
                cout << " for " << ITERS << " iterations on ";
            cout << strDat.size() << " bytes.\n";
        }
    }//end if speedtest
    return;
}

// Show command-line help on stderr, wait for the user to press return, then
// terminate the process with exit status 1.  Does not return.
void usage(const char *progName){
    cerr << progName
         << " <optional 'verbose'> <optional 'validate'>"
         << " <optional 'speedtest'> <optional speedtest iterations (int, default 1)>\n"
         << "\tYou must choose either 'validate', 'speedtest' or both\n";

    // Wait for a keypress before exiting.
    cout << "Press return to exit\n";
    cin.sync();
    cin.get();

    exit(1);
}

// Entry point.  Recognized arguments, in any order:
//   verbose    - report every individual digest check
//   validate   - run the test vectors found in ./testvectors
//   speedtest  - time each hash variant
//   <integer>  - speed-test iteration count (used only when greater than 1)
// Anything else is reported and ignored.  With no arguments, or without
// 'validate' or 'speedtest', usage() prints help and exits.
int main(int argc, char *argv[]){
    bool verbose = false, validate = false, speedtest = false;
    unsigned int speedTestIters = 1;

    if (argc <= 1) usage(argv[0]);

    for (int i = 1; i < argc; i++){
        const string arg(argv[i]);
        if (arg == "verbose")        verbose = true;
        else if (arg == "validate")  validate = true;
        else if (arg == "speedtest") speedtest = true;
        else{
            // Parse strictly: atoi() is undefined on overflow and accepts trailing junk.
            char *end = NULL;
            errno = 0;
            const long value = strtol(argv[i], &end, 10);
            const bool numeric = (end != argv[i]) && (*end == '\0') && (errno != ERANGE);
            if (!numeric || (value > 0 && static_cast<unsigned long>(value) > UINT_MAX)){
                cerr << "Ignoring invalid argument: " << arg << "\n";
            }else if (value > 1){
                speedTestIters = static_cast<unsigned int>(value);
            }
            // Counts of 1 or less keep the default single iteration.
        }
    }

    if (verbose) cout << "Verbose is set.\n";
    if (validate) cout << "Validate is set.\n";
    if (speedtest){
        cout << "Speedtest is set, running " << speedTestIters << " iterations.\n";
    }

    if (!validate && !speedtest){
        usage(argv[0]);
    }

    test(verbose, validate, speedtest, speedTestIters);

    cout << "Press return to exit\n";
    cin.sync();
    cin.get();
    return 0;
}
