解题思路:反向建图,拓扑dfs
AC_Code:
1 #include <bits/stdc++.h> 2 using namespace std; 3 typedef long long ll; 4 const int maxn = 5e4+10; 5 6 int s[maxn][2],_map[maxn]; 7 double ans[maxn],der[maxn],x[maxn];//求导:derivative 8 int in[maxn],f; 9 bool vis[maxn]; 10 int cnt; 11 vector<int>vec; 12 13 int n; 14 15 void dfs(int h){ 16 if( vis[h] ) return ; 17 vis[h] = true; 18 19 if( _map[h]==0 ){//叶子节点,变量 20 ans[h] = x[h]; 21 //根据友情提示的第二条 22 if( h==f ) der[h] = 1; 23 else der[h] = 0; 24 } else if( _map[h]==1 ){//加法求值,求偏导 25 int l=s[h][0], r=s[h][1]; 26 dfs(l); dfs(r); 27 ans[h] = ans[l] + ans[r]; 28 der[h] = der[l] + der[r]; 29 } else if( _map[h]==2 ){//减法求值,求偏导 30 int l=s[h][0], r=s[h][1]; 31 dfs(l), dfs(r); 32 ans[h] = ans[l] - ans[r]; 33 der[h] = der[l] - der[r]; 34 } else if( _map[h]==3 ){//乘法求值,求偏导 35 int l=s[h][0], r=s[h][1]; 36 dfs(l); dfs(r); 37 //复合函数求导根据友情提示第3条 38 ans[h] = ans[l] * ans[r]; 39 der[h] = der[l] * ans[r] + der[r] * ans[l]; 40 } else if( _map[h]==4 ){//指数求导 41 int v = s[h][0]; 42 dfs(v); 43 ans[h] = exp(ans[v]); 44 der[h] = exp(ans[v]) * der[v]; 45 } else if( _map[h]==5 ){//lnx求导 46 int v = s[h][0]; 47 dfs(v); 48 ans[h] = log(ans[v]); 49 der[h] = der[v] / ans[v]; 50 } else if( _map[h]==6 ){//sin求导 51 int v = s[h][0]; 52 dfs(v); 53 ans[h] = sin(ans[v]); 54 der[h] = cos(ans[v]) * der[v]; 55 } 56 } 57 int main() 58 { 59 scanf("%d",&n); 60 for(int i=0;i<n;i++){ 61 int k; scanf("%d",&k); 62 _map[i] = k; 63 if( !k ){ 64 scanf("%lf",&x[i]); 65 vec.push_back(i); 66 } 67 else if( k>=1 && k<=3 ){ 68 int u,v; scanf("%d%d",&u,&v); 69 s[i][0] = u; 70 s[i][1] = v; 71 in[u]++; 72 in[v]++; 73 } 74 else{ 75 int u; scanf("%d",&u); 76 s[i][0] = u; 77 in[u]++; 78 } 79 } 80 int s = 0; 81 for(int i=0;i<n;i++){ 82 if( !in[i] ){ 83 s = i; 84 break; 85 } 86 } 87 queue<double>q; 88 cnt = vec.size(); 89 for(int i=0;i<cnt;i++){ 90 f = vec[i];//指定要偏导的变量 91 memset(vis,false,sizeof(vis)); 92 dfs(s); 93 q.push(der[s]); 94 } 95 printf("%.3f ",ans[s]); 96 bool flag = false; 97 while( !q.empty()){ 98 double h = q.front(); q.pop(); 99 if( !flag ) printf("%.3f",h); 100 else printf(" %.3f",h); 101 flag = true; 102 } 103 printf(" "); 104 return 0; 105 }