Nowadays, automatic differentiation greatly facilitates the implementation and training of deep learning algorithms. Our task is to create an automatic differentiation program for algebraic expressions.
The input is an infix expression composed of brackets, the operators power (^), multiplication (*), division (/), addition (+) and subtraction (-), variables (strings of lowercase English letters) and literal constants.
For each variable in the expression, output an arithmetic expression that represents the derivative of the input expression with respect to that variable.
Arrange the output in the lexicographical order of the variables.
We use a struct to store each node of the expression tree. Every node has three components: a string holding the value of the node (an operator, a variable name or a constant number) and two pointers to its left child and its right child.
We use stacks to build the expression tree: nodestack stores the operands and the subtrees built so far, while opstack stores the pending operators and left brackets. The pseudo-code is shown below:
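for (every ch in the expression) {
    if (ch is a letter or a digit) {
        append ch to the current name;          // it must be part of a variable name or a number
        continue;
    }
    if (the current name is not empty)
        push a leaf node for the name into nodestack and clear the name;
    if (ch == '(')
        push '(' into opstack;
    if (ch == ')') {                            // build the subtree for the expression in the brackets
        while (the top of opstack is not '(') {
            pop operator from opstack;
            pop b, then a, from nodestack;      // b is the right operand, a the left operand
            push subtree(a, operator, b) into nodestack;
        }
        pop '(' from opstack;
    }
    if (ch is an operator) {
        while (the top of opstack is an operator whose precedence >= precedence(ch)) {
            pop operator from opstack;
            pop b, then a, from nodestack;
            push subtree(a, operator, b) into nodestack;
        }
        push ch into opstack;
    }
}
if (the current name is not empty)
    push a leaf node for the name into nodestack;
while (opstack is not empty) {
    pop operator from opstack;
    pop b, then a, from nodestack;
    push subtree(a, operator, b) into nodestack;
}
pop root from nodestack;                        // root is the root of the expression tree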
The derivative is computed recursively by differentiate(root, var). In the pseudo-code NULL stands for the constant 0, and subtree() drops zero operands: 0+g=g, f+0=f, f-0=f, 0-g is printed as -g, and a product with a zero factor or a quotient with a zero numerator is 0.
differentiate(root, var):
    if (root == NULL) return NULL;
    if (root->value is a constant number)         // F(x)=c, F'(x)=0
        return NULL;
    if (root->value == var)                       // F(x)=x, F'(x)=1
        return node("1");
    if (root->value is another variable)          // F(x)=y, F'(x)=0
        return NULL;
    if (root->value is an operator) {
        left  = differentiate(root->left, var);   // derivative of the left operand
        right = differentiate(root->right, var);  // derivative of the right operand
        if (root->value == "+")                   // F(x)=f(x)+g(x), F'(x)=f'(x)+g'(x)
            return subtree(left, "+", right);
        if (root->value == "-")                   // F(x)=f(x)-g(x), F'(x)=f'(x)-g'(x)
            return subtree(left, "-", right);
        if (root->value == "*") {                 // F(x)=f(x)*g(x), F'(x)=f'(x)*g(x)+f(x)*g'(x)
            build subtree s1(left, "*", root->right);
            build subtree s2(root->left, "*", right);
            return subtree s3(s1, "+", s2);
        }
        if (root->value == "/") {                 // F(x)=f(x)/g(x), F'(x)=[f'(x)*g(x)-f(x)*g'(x)]/g(x)^2
            build subtree s1(left, "*", root->right);
            build subtree s2(root->left, "*", right);
            build subtree s3(s1, "-", s2);
            build subtree s4(root->right, "^", "2");
            return subtree s5(s3, "/", s4);
        }
        if (root->value == "^") {                 // F(x)=f(x)^g(x), F'(x)=[g'(x)*ln f(x)+g(x)*f'(x)/f(x)]*f(x)^g(x)
            build subtree s1(right, "*", ln(root->left));
            build subtree s2(root->right, "*", left);
            build subtree s3(s2, "/", root->left);
            build subtree s4(s1, "+", s3);
            build subtree s5(root->left, "^", root->right);
            return subtree s6(s4, "*", s5);
        }
    }
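For step 1 (building the expression tree), we scan the expression once, and every operator, variable and number is pushed onto and popped from a stack at most once, so step 1 takes \(O(n)\) time (assume the length of the expression is \(n\)).
For step 2 (collecting and sorting the variable names), we scan the expression again and keep each distinct name once. Checking every token against the names collected so far by linear search costs \(O(nm)\) in total. Sorting the \(m\) names (e.g. with std::sort) costs \(O(m\log m)\) comparisons. Since \(m\le n\), step 2 takes \(O(nm+m\log m)=O(nm)\) time (assume there are \(m\) distinct variable names in the expression). Keeping the names in a balanced search tree such as std::set instead deduplicates and sorts them in \(O(n\log m)\) time.
For step 3 (output), we call differentiate once for each of the \(m\) variables. Each call visits every node of the expression tree once and creates a constant number of new nodes per visit, so computing all derivatives takes \(O(mp)\) time (assume there are \(p\) nodes in the expression tree). However, a derivative tree reuses subtrees of the input tree instead of copying them, and a reused subtree is printed in full every time it appears. Each input subtree is referenced only a constant number of times (by the rule applied at its parent), so printing one derivative takes time proportional to the sum of all subtree sizes. That sum is \(O(ph)\), where \(h\) is the height of the expression tree, and is \(O(p^2)\) in the worst case, e.g. for x*x*...*x. Step 3 therefore takes \(O(mph)\) time.
To sum up, the total time complexity of the program is \(O(n+nm+mph)\) when the names are deduplicated by linear search, or \(O(n+n\log m+mph)\) with a balanced search tree. Here \(n\) is the length of the expression, \(m\) the number of variables, \(p\) the number of nodes in the expression tree and \(h\) its height.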
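The program builds an expression tree with \(p\) nodes, and each call to differentiate adds \(O(p)\) new nodes; the derivative reuses subtrees of the input tree instead of copying them. It also keeps a constant number of strings holding the expression and the collected variable names, whose total length is \(O(n)\). If every derivative tree is kept until the program exits, the total space complexity is \(O(n+mp)\). If each derivative tree is released after it is printed, the total space complexity is \(O(n+p)\). Here \(n\) is the length of the expression, \(m\) the number of variables and \(p\) the number of nodes in the expression tree. The recursion in differentiate and in printing additionally uses stack space proportional to the height of the tree.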
The functionality of this program is still limited. It does not support mathematical functions such as \(\sin x,\cos x,\tan x,\ln x,\log(x,y),\exp(x)\) and so on, nor unary minus (e.g. -x) or decimal constants in the input. It also cannot simplify either the input expression or the output expression (for example, it prints 1*y rather than y). It still needs to be improved.
// Automatic differentiation of algebraic expressions.
//
// Reads one line holding an infix expression made of brackets, the binary
// operators ^ * / + -, variables (strings of letters) and integer constants.
// It builds the expression tree, then for every distinct variable, in
// lexicographical order, prints "name: derivative" on its own line.
// A derivative that is identically zero is printed as "0".
// Malformed input is reported on standard error with exit status 1.
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <deque>
#include <iostream>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

namespace {

// A node of an expression tree:
//  - leaf: a variable name or a constant (both children null);
//  - binary operator: "+", "-", "*", "/" or "^" with both children set;
//  - unary minus (derivative trees only): "-" with a null left child;
//  - natural logarithm (derivative trees only): "ln" with its argument in right.
struct TreeNode {
    std::string value;
    TreeNode* left = nullptr;
    TreeNode* right = nullptr;
};

// Owns every TreeNode. Derivative trees share subtrees with the input tree, so
// a node may have several parents and cannot be owned by any one of them.
// Nodes are only added and removed at the back; std::deque keeps references
// to the remaining elements valid under push_back/pop_back.
std::deque<TreeNode>& nodepool() {
    static std::deque<TreeNode> pool;
    return pool;
}

// Creates a node in the pool and returns a non-owning pointer to it.
TreeNode* createnode(const std::string& value, TreeNode* left = nullptr, TreeNode* right = nullptr) {
    nodepool().push_back(TreeNode{value, left, right});
    return &nodepool().back();
}

// Precedence of a binary operator; -1 for any other character (including '(').
int precedence(char op) {
    switch (op) {
        case '^':
            return 3;
        case '*':
        case '/':
            return 2;
        case '+':
        case '-':
            return 1;
        default:
            return -1;
    }
}

bool isoperatorchar(char c) { return precedence(c) > 0; }

// True if s is the value of an operator node (binary operator or unary minus).
bool isoperator(const std::string& s) { return s.size() == 1 && isoperatorchar(s[0]); }

// True if s is a non-empty string of decimal digits.
bool isconstant(const std::string& s) {
    return !s.empty() && std::all_of(s.begin(), s.end(), [](char c) { return c >= '0' && c <= '9'; });
}

bool isleaf(const TreeNode* node) { return node->left == nullptr && node->right == nullptr; }

bool isunaryminus(const TreeNode* node) {
    return node->value == "-" && node->left == nullptr && node->right != nullptr;
}

// An "ln" function node; a leaf for a variable that happens to be named "ln" is not one.
bool islog(const TreeNode* node) {
    return node->value == "ln" && node->left == nullptr && node->right != nullptr;
}

// Pops and returns the top of `stack`; throws on underflow (malformed input).
TreeNode* popnode(std::vector<TreeNode*>& stack) {
    if (stack.empty()) {
        throw std::runtime_error("invalid expression");
    }
    TreeNode* top = stack.back();
    stack.pop_back();
    return top;
}

// Pops the top operator of `opstack`, gives it the two topmost subtrees of
// `nodestack` as operands (the right operand is on top) and pushes the result.
void reduce(std::vector<TreeNode*>& nodestack, std::vector<TreeNode*>& opstack) {
    TreeNode* op = popnode(opstack);
    if (op->value == "(") {
        throw std::runtime_error("invalid expression");  // '(' without a matching ')'
    }
    op->right = popnode(nodestack);
    op->left = popnode(nodestack);
    nodestack.push_back(op);
}

// Builds the expression tree of an infix expression using an operand stack and
// an operator stack; whitespace is ignored. Throws std::runtime_error if the
// expression is malformed (empty, unbalanced brackets, missing operands or
// operators, unsupported characters).
TreeNode* buildtree(const std::string& expression) {
    std::vector<TreeNode*> nodestack;  // operands and finished subtrees
    std::vector<TreeNode*> opstack;    // pending operators and '(' markers
    std::string name;                  // variable name or number being read

    // A name ends at the first character that is not a letter or a digit.
    const auto flushname = [&nodestack, &name]() {
        if (!name.empty()) {
            nodestack.push_back(createnode(name));
            name.clear();
        }
    };

    for (const char ch : expression) {
        const unsigned char uch = static_cast<unsigned char>(ch);  // <cctype> needs a non-negative value
        if (std::isalnum(uch)) {
            name += ch;
            continue;
        }
        flushname();
        if (std::isspace(uch)) {
            continue;
        }
        if (ch == '(') {
            opstack.push_back(createnode("("));
        } else if (ch == ')') {
            while (!opstack.empty() && opstack.back()->value != "(") {
                reduce(nodestack, opstack);
            }
            if (opstack.empty()) {
                throw std::runtime_error("invalid expression");  // ')' without a matching '('
            }
            opstack.pop_back();  // discard the matching '('
        } else if (isoperatorchar(ch)) {
            // Reduce pending operators of higher or equal precedence first, so
            // operators of the same precedence group from left to right.
            while (!opstack.empty() && precedence(opstack.back()->value[0]) >= precedence(ch)) {
                reduce(nodestack, opstack);
            }
            opstack.push_back(createnode(std::string(1, ch)));
        } else {
            throw std::runtime_error("invalid expression");  // unsupported character
        }
    }
    flushname();
    while (!opstack.empty()) {
        reduce(nodestack, opstack);
    }
    if (nodestack.size() != 1) {
        throw std::runtime_error("invalid expression");  // empty input or operands without an operator
    }
    return nodestack.back();
}

// Derivative trees use nullptr for the constant 0. The builders below fold
// zero operands away so that terms known to be 0 are not printed.

// f + g, with 0 + g = g and f + 0 = f.
TreeNode* addnodes(TreeNode* left, TreeNode* right) {
    if (left == nullptr) {
        return right;
    }
    if (right == nullptr) {
        return left;
    }
    return createnode("+", left, right);
}

// f - g, with f - 0 = f; 0 - g becomes a unary minus node printed as -g.
TreeNode* subnodes(TreeNode* left, TreeNode* right) {
    if (right == nullptr) {
        return left;
    }
    return createnode("-", left, right);
}

// f * g, which is 0 when either factor is 0.
TreeNode* mulnodes(TreeNode* left, TreeNode* right) {
    if (left == nullptr || right == nullptr) {
        return nullptr;
    }
    return createnode("*", left, right);
}

// f / g, which is 0 when the numerator is 0. Callers never pass a null
// denominator; if one is passed, the result is also nullptr.
TreeNode* divnodes(TreeNode* left, TreeNode* right) {
    if (left == nullptr || right == nullptr) {
        return nullptr;
    }
    return createnode("/", left, right);
}

// f ^ g.
TreeNode* powernodes(TreeNode* left, TreeNode* right) { return createnode("^", left, right); }

// ln(f), keeping the whole argument subtree.
TreeNode* lnnode(TreeNode* argument) { return createnode("ln", nullptr, argument); }

// Returns the derivative of the tree `root` with respect to `var`, or nullptr
// if it is identically 0. The result may share subtrees with `root`; `root`
// itself is not modified.
TreeNode* differentiate(TreeNode* root, const std::string& var) {
    if (root == nullptr) {
        return nullptr;
    }
    if (isleaf(root)) {
        // F = var gives F' = 1; a constant or any other variable gives F' = 0.
        return root->value == var ? createnode("1") : nullptr;
    }
    TreeNode* f = root->left;
    TreeNode* g = root->right;
    TreeNode* df = differentiate(f, var);
    TreeNode* dg = differentiate(g, var);
    const std::string& op = root->value;
    if (op == "+") {  // (f + g)' = f' + g'
        return addnodes(df, dg);
    }
    if (op == "-") {  // (f - g)' = f' - g'
        return subnodes(df, dg);
    }
    if (op == "*") {  // (f * g)' = f' * g + f * g'
        return addnodes(mulnodes(df, g), mulnodes(f, dg));
    }
    if (op == "/") {  // (f / g)' = (f' * g - f * g') / g ^ 2
        TreeNode* numerator = subnodes(mulnodes(df, g), mulnodes(f, dg));
        if (numerator == nullptr) {
            return nullptr;
        }
        return divnodes(numerator, powernodes(g, createnode("2")));
    }
    if (op == "^") {  // (f ^ g)' = (g' * ln(f) + g * f' / f) * f ^ g
        TreeNode* term1 = mulnodes(dg, lnnode(f));
        TreeNode* term2 = divnodes(mulnodes(g, df), f);
        TreeNode* sum = addnodes(term1, term2);
        if (sum == nullptr) {
            return nullptr;
        }
        return mulnodes(sum, powernodes(f, g));
    }
    return nullptr;  // not reached for trees produced by buildtree
}

void printtree(const TreeNode* root);

// Whether `child` must be bracketed when printed as the left (isright == false)
// or right (isright == true) operand of `parent`, so that the text means the
// same as the tree.
bool needsbrackets(const TreeNode* parent, const TreeNode* child, bool isright) {
    if (!isoperator(parent->value) || !isoperator(child->value)) {
        return false;  // leaves and ln(...) print unambiguously
    }
    if (isright && isunaryminus(child)) {
        return true;  // a-(-b), a*(-b) rather than a--b, a*-b
    }
    const int outer = precedence(parent->value[0]);
    const int inner = precedence(child->value[0]);
    if (inner != outer) {
        return inner < outer;
    }
    // Equal precedence: a left operand needs no brackets (operators of one
    // level group left to right), and a right operand of + or * may be
    // regrouped without changing the value (a+(b-c) = a+b-c, a*(b/c) = a*b/c).
    // A right operand of - or / would change the meaning, and ^ is bracketed
    // on both sides so the output does not rely on its associativity.
    const char op = parent->value[0];
    return op == '^' || (isright && (op == '-' || op == '/'));
}

// Prints `child` as an operand of `parent`, adding brackets when required.
void printoperand(const TreeNode* parent, const TreeNode* child, bool isright) {
    const bool brackets = needsbrackets(parent, child, isright);
    if (brackets) {
        std::cout << '(';
    }
    printtree(child);
    if (brackets) {
        std::cout << ')';
    }
}

// Prints the tree as an infix expression to standard output.
void printtree(const TreeNode* root) {
    if (root == nullptr) {
        return;
    }
    if (islog(root)) {
        std::cout << "ln(";
        printtree(root->right);
        std::cout << ')';
        return;
    }
    if (root->left != nullptr) {
        printoperand(root, root->left, false);
    }
    std::cout << root->value;
    if (root->right != nullptr) {
        printoperand(root, root->right, true);
    }
}

// Adds the name of every variable leaf of the tree to `names`; constants are skipped.
void collectvariables(const TreeNode* node, std::set<std::string>& names) {
    if (node == nullptr) {
        return;
    }
    if (isleaf(node)) {
        if (!isconstant(node->value)) {
            names.insert(node->value);
        }
        return;
    }
    collectvariables(node->left, names);
    collectvariables(node->right, names);
}

}  // namespace

int main() {
    std::string expression;
    std::getline(std::cin, expression);

    TreeNode* root = nullptr;
    try {
        root = buildtree(expression);
    } catch (const std::runtime_error& error) {
        std::cerr << error.what() << '\n';
        return 1;
    }

    // std::set both removes duplicates and keeps the names in lexicographical order.
    std::set<std::string> variables;
    collectvariables(root, variables);

    for (const std::string& variable : variables) {
        const std::size_t mark = nodepool().size();
        const TreeNode* derivative = differentiate(root, variable);
        std::cout << variable << ": ";
        if (derivative == nullptr) {
            std::cout << '0';  // nullptr encodes the constant 0
        } else {
            printtree(derivative);
        }
        std::cout << '\n';
        // Release this derivative's nodes; the input tree lives below `mark` and stays valid.
        while (nodepool().size() > mark) {
            nodepool().pop_back();
        }
    }
    return 0;
}