/**
 * A naive implementation of reverse accumulation of the first order derivatives
 * Written on 24 January 2010, Geilo, Norway
 * Author: Daniel Wilczak
 *
 * Remarks: the code is as bad as possible, but it shows that the basic AD
 * can be implemented within a few minutes.
 *
 * The code contains some errors that can be easily fixed.
 * For example, there is no memory management, i.e.
 * allocated memory is never released.
 * The base class B declares a virtual destructor, so a clean-up routine
 * that deletes nodes through a B* can be added safely.
 *
 * Compile (linux): g++ backwardDiffExample.cpp -o backwardDiffExample
 * Run: ./backwardDiffExample
 *
 * Comments added on 27 January 2010, Oslo Airport
 */

#include <cstddef>
#include <iostream>

// The basic class for Reverse Mode.
// We will record the expression as a tree.
// When the expression is evaluated the actual value
// of all intermediate variables will be stored in
// the member variable 'value'.
// Moreover, the DAG of the expression will be automatically recorded.
//
// The virtual function diff when called for the root of the tree
// will accumulate the first order derivatives.
class B {
public:
  // When only the value of a node is known, we assume that it is a variable.
  // Hence, there are no left and right subtrees.
  B(double _x)
    : value(_x), der(0.), left(NULL), right(NULL) {}

  // A general constructor to initialize an object.
  // The node stores the two pointers as given; it does not take ownership.
  B(double _x, B* _left, B* _right)
    : value(_x), der(0.), left(_left), right(_right) {}

  // B is used polymorphically (diff is virtual), so the destructor must be
  // virtual as well: deleting a derived node through a B* is undefined
  // behaviour otherwise. It deliberately does not delete left/right,
  // because a node does not know whether its children live on the heap.
  virtual ~B() {}

  // Virtual function to accumulate derivatives.
  // By default it does nothing. This is for variables, constants.
  virtual void diff() {}

  // Actual value of this intermediate variable.
  double value;
  // Derivative of the output expression wrt this intermediate variable.
  // Initialized to zero by both constructors.
  double der;

  // Left and right subexpressions, if they exist (NULL otherwise).
  B* left;
  B* right;
};

// Class BSum inherits from the class B.
// It redefines virtual function diff for addition.
class BSum : public B { public: // When create a new node that represents addition // we just add the values from left and right subexpressions // and store the pointers to nodes for those subexpressions. BSum(B& _left, B& _right) : B(_left.value + _right.value, &_left, &_right) {} void diff() { // we have computed df/dv and stored in the member variable 'der' // if v=left+right then // df/dleft = df/dv // and // df/dright = df/dv // so we just simply add these values left->der += der; right->der += der; // accumulate left and right subexpressions left->diff(); right->diff(); } }; // Class BDif inherites from the class B // It redefines virtual function diff for subtraction class BDif : public B { public: BDif(B& _left, B& _right) : B(_left.value - _right.value, &_left, &_right) {} void diff() { // we have computed df/dv and stored in the member variable 'der' // if v=left-right then // df/dleft = df/dv // and // df/dright = -df/dv left->der += der; right->der -= der; // accumulate left and right subexpressions left->diff(); right->diff(); } }; // Class BMul inherites from the class B // It redefines virtual function diff for multiplication class BMul : public B { public: BMul(B& _left, B& _right) : B(_left.value * _right.value, &_left, &_right) {} void diff() { // we have computed df/dv and stored in the member variable 'der' // if v=left*right then // df/dleft = df/dv * right // and // df/dright = df/dv * left left->der += der * right->value; right->der += der * left->value; // accumulate left and right subexpressions this->left->diff(); this->right->diff(); } }; // Class BDiv inherites from the class B // It redefines virtual function diff for division class BDiv : public B { public: BDiv(B& _left, B& _right) : B(_left.value / _right.value, &_left, &_right) {} void diff() { // we have computed df/dv and stored in the member variable 'der' // if v=left/right then // df/dleft = (df/dv) / right // and // df/dright = - df/dv * v/right left->der += der/right->value; 
right->der -= der*value/right->value; // accumulate left and right subexpressions left->diff(); right->diff(); } }; // -------------------------------------------------------- // operator overloading. // Each operator takes two subexpressions as arguments // and returns a reference to new created node in DAG. // In C++ we can return reference to the base class B // even in the object is of the inherited type. B& operator+(B& left, B& right) { return *(new BSum(left,right)); } B& operator-(B& left, B& right) { return *(new BDif(left,right)); } B& operator*(B& left, B& right) { return *(new BMul(left,right)); } B& operator/(B& left, B& right) { return *(new BDiv(left,right)); } int main() { // We compute the gradient of // f(x,y) = (x+y+1)/(x*y-1) // see pdf file with materials // and compare with the example for forward propagation B x(2), y(-2), one(1); B& f = (x+y+one)/(x*y-one); // Now, f.value store f(2,-2) std::cout << "f=" << f.value << std::endl; // We initialize the derivative at the root of DAG, i.e. // df/df = 1 f.der=1; // and start reverse accumulation of the derivatives f.diff(); // now x.der and y.der contain df/dx and df/fy, respectively std::cout << "df/dx=" << x.der << std::endl; std::cout << "df/dy=" << y.der << std::endl; }