/*
    Expression parser (C) 2004 lifejunkie

    This program is free software; you can redistribute it and/or
    modify it under the terms of the GNU General Public License
    as published by the Free Software Foundation; either version 2
    of the License, or (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

    Contact: huntjas2@msu.edu
*/

#include <iostream.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <vector.h>
#include <math.h>


/***
* Class definitions for a parse tree class
***/


/* Base class; every node of the parse tree is derived from this class */
class element
{
  public:
    // Virtual so that deleting a node through an element pointer runs the
    // destructor of the derived class (an expression has children to free).
    virtual ~element(void) { }

    virtual void show(void) { }           // display hook, overridden by the derived classes
    enum {t_num, t_op, t_exp} node_type;  // tag naming the derived type; set by derived constructors
};


/* A numeric node of an expression, derived from element */
class number: public element
{
  public:
    double value;  // the number held by this node

    // store the number and tag the node as t_num
    number(double num) : value(num) { node_type = t_num; }

    // print the stored number to cout
    void show(void) { cout << value; }
};


/* An operator node of an expression, derived from element */
class operation: public element
{
  public:
    char op;  // the operator character, e.g. '+'

    // store the operator and tag the node as t_op
    operation(char the_op) : op(the_op) { node_type = t_op; }

    // print the operator character to cout
    void show( void ) { cout << op; }
};


/* Expression element: an ordered list of numbers, operations and nested
   sub expressions.  Every element added to an expression is deleted by
   clear(), which the destructor calls. */
class expression: public element
{
  public:
    // a vector containing POINTERS to the sub elements, in the order they were added
    vector<element *> elements;

    expression(void) { node_type = t_exp; }  // constructor, tags the node as t_exp
    ~expression(void) { clear(); }           // free all of the added elements

    // NOTE: the class holds raw pointers and defines no copy operations, so a
    // copied expression would delete the same elements twice; do not copy.

    // add a number to the expression (double, so parsed values keep full precision)
    void add_num(double num) { elements.push_back(new number(num)); }
    // add an operation to the expression
    void add_op(char op) { elements.push_back(new operation(op)); }
    // add a sub expression; exp is deleted by clear() from now on
    void add_exp(expression *exp) { elements.push_back(exp); }

    // scan exp and collapse every "number the_op number" triple into one number
    void  apply_operation(expression *exp, char the_op);

    double evaluate(void); // reduce the expression (modifying it) and return its value
    void clear(void);      // delete and remove all elements of the expression
    void show( void );     // display the expression
};



// one entry of the variable table: a name paired with its value
struct variable_def
{
    char *name;     // pointer to the variable's name string
    double value;   // value substituted for the variable
};


/***
* That's it for the class definitions
***/




/***
* The class methods
***/

// free memory and remove all elements form an expression
void expression::clear(void)
{
    // free the memory for the elements in the vector
    for (int i = 0; i < elements.size(); i++)
        delete elements[i];

    // empty the vector
    elements.clear();
}

// this displays an expression
void expression::show(void)
{
    cout << "(";
    for (int i = 0; i < elements.size(); i++)
        elements[i]->show();
    cout << ")";
}

// scans the expresison and applies the operation in the second agrument
void expression::apply_operation(expression *exp, char the_op)
{
    expression *tmp = exp;  // create a pointer to the expression

    for (int i = 0; i < tmp->elements.size();i++)
    {
        if (tmp->elements[i]->node_type == t_op)
        {
            // create a pointer to the operation that we are going to apply
            operation *tmp_op = (operation *) tmp->elements[i];

            // make sure the expresison makes sense
            if ((tmp->elements[i-1]->node_type != t_num) ||
               (tmp->elements[i+1]->node_type != t_num))
            {
                exp->clear();  // wipe the expresison out
                cout << "Invalid expression." << endl;
                return;
            }

            // create pointers to the arguments
            number *left = (number *) tmp->elements[i-1];
            number *right = (number *) tmp->elements[i+1];

            if (tmp_op->op == the_op)
            {
                double result;

                if (the_op == '+') result = left->value + right->value; else  // addition
                if (the_op == '-') result = left->value - right->value; else  // subtraction
                if (the_op == '*') result = left->value * right->value; else  // multiplication
                if (the_op == '^') result = pow(left->value, right->value); else  // power
                if ((the_op == '/') && (right->value != 0))                   // division
                    result = left->value / right->value; else
                result = 0; // unknown or division by zero

                // replace the left hand argument with the new value
                tmp->elements[i-1] = new number(result);

                // free the memory for the operation
                delete tmp->elements[i];

                // free the memory for the right hand numbers
                delete tmp->elements[i+1];

                // remove the operation from the elements
                tmp->elements.erase(tmp->elements.begin() + i);

                // remove the left hand number
                tmp->elements.erase(tmp->elements.begin() + i);

                i-= 2; // we removed two elements, go back
            }
        }
    }
}


// evaluate an expression and return the value
double expression::evaluate(void)
{
    expression *tmp = this; // create a pointer to the expression

    // scan for and collapse all of the sub expressions
    for (int i = 0; i < tmp->elements.size();i++)
    {
        // check for sub expression
        if (tmp->elements[i]->node_type == t_exp)
        {
            // create a pointer for out sub expression
            expression *sub_exp = (expression *)tmp->elements[i];

            // evaluate the subexpression
            double num = sub_exp->evaluate();

            // free the memory for the old subexpression
            delete tmp->elements[i];

            // replace it with out number
            tmp->elements[i] = new number(num);

            // we movified the vector, go back one to check the next value
            i--;
        }
    }

    // this applies the order of operations
    apply_operation(tmp, '^');
    apply_operation(tmp, '*');
    apply_operation(tmp, '/');
    apply_operation(tmp, '+');
    apply_operation(tmp, '-');

    // the the expression reduce fully?
    if (tmp->elements.size() == 1)
    {
        number *last_value = (number *) tmp->elements[0];
        return(last_value->value);
    }
    else
    {
        cout << "Invalid expression" << endl;
        return(0);
    }
}

// contains the table of defined variables
class var_table
{
  public:
    vector<variable_def> vars;

    ~var_table() { clear(); };

    void add(char *name, float the_val);  // add a variable and value
    void clear(void);  // free the memory
    void show(void); // display the table
};


// add a variable name and values to the table
void var_table::add(char *the_name, float the_val)
{
    variable_def tmp;
    tmp.name = new char[strlen(the_name)];  // get some memory
    strcpy(tmp.name, the_name);
    tmp.value = the_val;

    vars.push_back(tmp); // add it
}

// free the table
void var_table::clear()
{
    for (int i = 0; i < vars.size(); i++)
        delete vars[i].name;

    vars.clear();
}

// display the table
void var_table::show(void)
{
    for (int i = 0; i < vars.size(); i++)
        cout << "(" << vars[i].name << ") = " << vars[i].value << endl;
}

/***
* End of the class methods
***/




/***
* Parsing functions
***/

// nonzero when ch can be part of a number: '0' through '9' or a decimal point
int inline isdigit(char ch)
{
    if (ch == '.')
        return(1);

    return('0' <= ch && ch <= '9');
}

// nonzero when ch is a blank or a tab; newlines and the null terminator are not whitespace
int inline iswhitespace(char ch)
{
    const bool is_blank = (ch == ' ');
    const bool is_tab = (ch == '\t');

    return(is_blank || is_tab);
}

// nonzero when ch is one of the operators + - * / ^
int inline isoperation(char ch)
{
    switch (ch)
    {
        case '+':
        case '-':
        case '*':
        case '/':
        case '^':
            return(1);

        default:
            return(0);
    }
}

// nonzero when ch may appear in a variable name: a lower case letter or an underscore
int inline isvarname(char ch)
{
    if (ch == '_')
        return(1);

    return(!(ch < 'a' || ch > 'z'));
}

// advance pos past any blanks and tabs, stopping on the first other character
void eat_whitespace(char *str, int &pos)
{
    // iswhitespace() is zero for the null terminator, so the scan stops
    // at the end of the string at the latest
    for (; iswhitespace(str[pos]); ++pos)
        ;
}

// returns the next number in a string
//
// Collects digit characters (as defined by isdigit(), so '.' counts) starting
// at str[pos] and returns their value as converted by atof().  Blanks and tabs
// met while scanning, including ones between digits, are skipped.  On return
// pos indexes the first character that is neither a digit nor whitespace, or
// the null terminator.
//
// len is only used as a size hint for the digit buffer.
double get_number(char *str, int &pos, int len)
{
    // growable buffer rather than a variable length array: that is not
    // standard C++, and this one cannot overflow if len is too small
    vector<char> num_str;
    if (len > 0)
        num_str.reserve(len + 1);

    // seek until a non digit or null
    for (;;)
    {
        // eat white space
        eat_whitespace(str, pos);

        // stop at end of string or at the first non digit
        if (!str[pos] || !isdigit(str[pos]))
            break;

        // add the next digit to the number string
        num_str.push_back(str[pos++]);
    }

    // null terminate the number string
    num_str.push_back(0);

    // return the floating point value
    return(atof(&num_str[0]));
}


// parse the expression in the string
int parse_expression(char *exp_str, expression &exp, var_table &the_vars)
{
    // we need the expression length later
    int exp_len = strlen(exp_str);

    // inital position is zero
    int pos = 0;

    // loop until null terminator
    while (exp_str[pos])
    {
        // eat the white space
        eat_whitespace(exp_str, pos);

        // break if we seeked to the end
        if (pos == exp_len) break;

        if (isdigit(exp_str[pos])) exp.add_num(get_number(exp_str, pos, exp_len)); else  // number?
        if (isoperation(exp_str[pos])) exp.add_op(exp_str[pos++]); else       // operation ?
        if (exp_str[pos] == ')')  return(pos); else // this is the termination of a sub expression
        if (exp_str[pos] == '(')
        {
            // create and parse a subexpression
            expression *tmp = new expression; // allocate memory for the new expression
            pos += parse_expression(exp_str + pos + 1, *tmp, the_vars) + 2;  // parse the sub expression
            exp.add_exp(tmp); // add it to our main expression
        }else   // this section checks for variables
        {
            // grab the next object
            char obj_name[1024];
            int name_pos = 0, i;

            // construct the object name string
            while (isvarname(exp_str[pos]))
                obj_name[name_pos++] = exp_str[pos++];

            obj_name[name_pos++] = 0;  // term the string

            for (i = 0; i < the_vars.vars.size(); i++)
                if (!strcmp(the_vars.vars[i].name, obj_name))   // did we find this object
                {
                    exp.add_num(the_vars.vars[i].value); // add the actual values to the expression
                    break; // kill the for loop
                }

            // for loop overran, object not found
            if (i == the_vars.vars.size())
            {
                cout << "(" << obj_name << ") not found" << endl;
                return(0);
            }

        }
    }
    return(pos); // return number of bytes consumed
}


/*****************************
*   End of Parsing functions *
*****************************/



// Interactive loop: read a set of variables, read an expression, then print
// its parse tree and value.  Typing "exit" at the variable name or expression
// prompt quits; so does end of input or a read error on any prompt.
int main(int argc, char *argv[])
{
    expression exp;   // the parse tree for the expression
    var_table vars;

    char string_exp[1024];  // the expression in string form

    cout << "Basic expression parser. Type exit to quit" << endl << endl;
    cout << "Hit enter to stop adding variables" << endl;

    while(1)
    {
        vars.clear(); // clear the var table

        // get variable names and values
        while(1)
        {
            char var_name[1024];
            char var_value[1024];

            cout << "Enter variable name: ";

            // a failed read (end of input, over long line) would otherwise
            // leave the stream failed and repeat the prompts forever
            if (!cin.getline(var_name, 1024)) return(0);

            if (!strcmp(var_name, "exit")) return(0); // program quit

            if (!var_name[0]) break; // done adding vars

            cout << "Enter value: ";
            if (!cin.getline(var_value, 1024)) return(0);

            // atof keeps the fractional part of values such as "2.5"
            vars.add(var_name, atof(var_value));
        }

        cout << "Variable table: " << endl;
        vars.show();

        // make sure the expression is empty
        exp.clear();

        // get the expression in string form
        cout << "Enter expression: ";
        if (!cin.getline(string_exp, 1024)) return(0);

        // exit on break
        if (!strcmp(string_exp, "exit")) return(0);

        // create the parse tree
        parse_expression(string_exp, exp, vars);

        // show the parse tree
        cout << "Parse tree expression: ";
        exp.show();
        cout << endl;

        double val = exp.evaluate(); // evaluate the expression

        cout << "Value: " << val << endl;
    }

    return 0;
}


