这个问题的核心在于用线段树找最长的连续的1 的位置,其他就基本操作了,不是很难想,只是不好写而已,具体看代码吧,没什么巧妙的东西,和那个贪婪大陆一样
#include<iostream> using namespace std; typedef long long ll; const int maxn = 6e5+1111; struct Node{ ll ans,LL,RR; ll lazy; ll sum; }cns[maxn*4],tree[maxn*4];//一个维护最长连续的1,一个维护区间和 int push(int node,int be,int en){ int mid = be + en >> 1; int l = node*2; int r = node*2+1; if(tree[node].lazy){ tree[l].lazy += tree[node].lazy; tree[r].lazy += tree[node].lazy; tree[l].sum += 1LL*(mid - be + 1)*tree[node].lazy; tree[r].sum += 1LL*(en - mid)*tree[node].lazy; tree[node].lazy = 0; } return 0; } int add(int node,int be,int en,int LL,int RR,ll val){ int mid = be + en >> 1; int l = node*2; int r = node*2+1; if(LL <= be && en <= RR){ tree[node].sum += 1LL*(en - be + 1)*val; tree[node].lazy += val; return 0; } push(node,be,en); if(LL <= mid) add(l,be,mid,LL,RR,val); if(RR > mid) add(r,mid+1,en,LL,RR,val); tree[node].sum = tree[l].sum + tree[r].sum; } //--------------------------区间和 int update(int node,int be,int en,int i,int val){//最长连续的1 int mid = be + en >> 1; int l = node*2; int r = node*2+1; if(be == en){ cns[node].ans = val; cns[node].LL = cns[node].RR = val; return 0; } if(i <= mid) update(l,be,mid,i,val); else update(r,mid+1,en,i,val); cns[node].ans = max(cns[l].ans,cns[r].ans); cns[node].ans = max(cns[node].ans,cns[l].RR + cns[r].LL); cns[node].LL = cns[l].LL; cns[node].RR = cns[r].RR; if(cns[l].ans == mid - be + 1) cns[node].LL = mid - be + 1 + cns[r].LL; if(cns[r].ans == en - mid) cns[node].RR = en - mid + cns[l].RR; } int id= 0; int find(int node,int be,int en,int x){//1---n中长度比x大,且最左边的位置坐标 int mid = be + en >> 1; int l = node*2; int r = node*2+1; if(be == en){ id = be; return 0; } if(cns[l].ans >= x) find(l,be,mid,x); else if(cns[l].RR + cns[r].LL >= x){ id = mid - cns[l].RR + 1; return 0; } else if(cns[r].ans >= x){ find(r,mid+1,en,x); } return 0; } int a[maxn]; ll cnt[maxn]; int main(){ int n; scanf("%d",&n); for(int i=1;i<=n;i++){ int x; scanf("%01d",&x); a[i] = x; if(x == 1) update(1,1,n,i,1); add(1,1,n,i,i,cns[1].ans); } ll d = 0; for(int i=n;i>=1;i--){ if(a[i]){ d++; cnt[i] = d; } else{ d = 0; cnt[i] = 0; } } ll ans = 0; for(int i=1;i<=n;i++){ ans += tree[1].sum; if(a[i] == 1){ int len = cnt[i]; update(1,1,n,i,0); id = 0; find(1,1,n,len); if(id != 0){ add(1,1,n,i,id + len - 2,-1LL); } else{ add(1,1,n,i,n,-1LL); } } } cout<<ans<<endl; return 0; }