2 条题解
-
0
// Solution 1.
// For every i in [0, n/2], the contiguous segment a[i+1 .. i+n/2] is removed
// and the remaining elements are the prefix a[1..i] plus the suffix
// a[n/2+i+1 .. n]. Prefix/suffix aggregates (min, max, second-min,
// second-max, sum) let each i be examined in O(1).
// NOTE(review): assumes n is even (so exactly n/2 elements remain) and that
// the values of a are distinct (presumably a permutation of 1..n) — confirm
// against the problem statement.
#include <algorithm>
#include <cstring>
#include <iostream>

namespace z {

using ll = long long;

constexpr int N = 500005;  // capacity for 1-indexed arrays plus index n + 1

ll n, ans;
ll a[N];
ll pmx[N], pmn[N];  // prefix max / min of a[1..i]
ll smx[N], smn[N];  // suffix max / min of a[i..n]
ll pre[N], suf[N];  // prefix / suffix sums
ll pcx[N], pcn[N];  // second-largest / second-smallest of a[1..i]
ll scx[N], scn[N];  // second-largest / second-smallest of a[i..n]

void main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n;
    // Large sentinels so that the first min() picks a[1] (resp. a[n]).
    std::memset(pmn, 0x3f, sizeof(pmn));
    std::memset(smn, 0x3f, sizeof(smn));
    for (ll i = 1; i <= n; i++) {
        std::cin >> a[i];
        pmx[i] = std::max(pmx[i - 1], a[i]);
        pmn[i] = std::min(pmn[i - 1], a[i]);
        pre[i] = pre[i - 1] + a[i];
    }

    // Second-largest of a[1..i]; 0 for a single-element prefix.
    ll mx = 0;
    for (ll i = 1; i <= n; i++) {
        if (a[i] > mx) {
            pcx[i] = mx;
            mx = a[i];
        } else {
            pcx[i] = std::max(a[i], pcx[i - 1]);
        }
    }

    // Second-smallest of a[1..i]; the 2e9 sentinel for a single-element
    // prefix (assuming values are below 2e9).
    ll mn = 2000000000;
    for (ll i = 1; i <= n; i++) {
        if (a[i] < mn) {
            pcn[i] = mn;
            mn = a[i];
        } else {
            pcn[i] = std::min(a[i], pcn[i - 1]);
        }
    }

    // Second-largest of a[i..n].
    mx = 0;
    for (ll i = n; i >= 1; i--) {
        if (a[i] > mx) {
            scx[i] = mx;
            mx = a[i];
        } else {
            scx[i] = std::max(a[i], scx[i + 1]);
        }
    }

    // Second-smallest of a[i..n].
    mn = 2000000000;
    for (ll i = n; i >= 1; i--) {
        if (a[i] < mn) {
            scn[i] = mn;
            mn = a[i];
        } else {
            scn[i] = std::min(a[i], scn[i + 1]);
        }
    }

    // Suffix max / min / sum; index n + 1 holds the empty-suffix values.
    for (ll i = n; i >= 1; i--) {
        smx[i] = std::max(smx[i + 1], a[i]);
        smn[i] = std::min(smn[i + 1], a[i]);
        suf[i] = suf[i + 1] + a[i];
    }

    const ll half = n / 2;
    for (ll i = 0; i <= half; i++) {
        const ll j = half + i + 1;  // first index of the kept suffix
        const ll sum = pre[i] + suf[j];
        ll maxn, minn, cmax, cmin;  // max, min, 2nd max, 2nd min of kept values
        if (i != 0 && i != half) {
            // Both the prefix and the suffix are non-empty: merge their stats.
            minn = std::min(pmn[i], smn[j]);
            maxn = std::max(pmx[i], smx[j]);
            cmax = std::max({std::min(pmx[i], smx[j]), pcx[i], scx[j]});
            cmin = std::min({std::max(pmn[i], smn[j]), pcn[i], scn[j]});
        } else if (i != 0) {
            // i == half: only the prefix is kept (for even n).
            minn = pmn[i];
            maxn = pmx[i];
            cmax = pcx[i];
            cmin = pcn[i];
        } else {
            // i == 0: only the suffix is kept.
            minn = smn[j];
            maxn = smx[j];
            cmax = scx[j];
            cmin = scn[j];
        }

        // For distinct integers, the half kept values are exactly the run
        // minn, minn + 1, ..., minn + half - 1 iff their sum equals the run's
        // sum (2*minn + half - 1) * half / 2, written as (...) * n / 4 for even n.
        if ((minn + minn + half - 1) * n / 4 == sum) {
            ans += (half - 1) * n / 2;
            if (minn != 1) ans++;
            if (maxn < n) ans++;
            continue;
        }

        // Otherwise compare the second-largest / second-smallest kept value
        // with the endpoints of a run of `half` values anchored at minn / maxn.
        // NOTE(review): the counting rule behind these increments comes from
        // the problem statement, which is not visible here.
        if (cmax == minn + half - 1) ans++;
        if (cmax == minn + half - 2) {
            ans++;
            if (minn != 1) ans++;
        }
        if (cmin == maxn - half + 1) ans++;
        if (cmin == maxn - half + 2) {
            ans++;
            if (maxn < n) ans++;
        }
    }
    std::cout << ans << '\n';
}

}  // namespace z

int main() {
    z::main();
    return 0;
}
-
0
Answer is here!
// Solution 2 ("Answer is here!").
// Slides a window of m = n/2 consecutive positions over a[1..n]. The values
// inside the current window are "key" values, kept in an ordered set together
// with the sentinels 0 and n + 1. Over all length-m intervals of values in
// [1, n], two counters are maintained incrementally:
//   x — number of intervals that contain no key value;
//   y — number of intervals that contain exactly one key value.
// NOTE(review): assumes the values of a are distinct and lie in [1, n]
// (presumably a permutation) — confirm against the problem statement.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

using ll = long long;

int n, m, x, y;
std::set<int> st;

// Counts plans with no key point: adds k * (number of length-m intervals
// that fit inside a gap of `len` consecutive non-key values) to x.
void Calc(int len, int k) {
    if (len >= m) x += k * (len - m + 1);
}

// Counts plans with exactly one key point, namely *it: adds
// k * (number of length-m intervals whose only key value is *it) to y.
// Sentinels are skipped, which also keeps prev/next below in range.
void Get(std::set<int>::iterator it, int k) {
    if (*it < 1 || *it > n) return;
    const int l = *std::prev(it), p = *it, r = *std::next(it);
    if (r - l - 1 < m) return;
    // Interval [s, s + m - 1] must lie strictly between l and r and contain p.
    const int lef = std::max(l + 1, p - m + 1);
    const int rig = std::min(p, r - m);
    y += k * (rig - lef + 1);
}

// Inserts key value k (assumed not already present), updating x and y for
// every gap and neighbouring key whose count changes.
void Add(int k) {
    auto it = st.lower_bound(k);  // the key just above k
    const int l = *std::prev(it), r = *it;
    Calc(r - l - 1, -1);     // gap (l, r) is about to be split by k ...
    Get(std::prev(it), -1);  // ... which shrinks l's right-hand gap
    Calc(k - l - 1, 1);      // new gap (l, k)
    Get(it, -1);             // r's left-hand gap shrinks as well
    Calc(r - k - 1, 1);      // new gap (k, r)
    it = st.insert(k).first;
    // Re-add the single-key counts of l, r and the new key k.
    Get(std::prev(it), 1);
    Get(std::next(it), 1);
    Get(it, 1);
}

// Removes key value k (which must be present); the inverse of Add.
void Del(int k) {
    auto it = st.find(k);
    const int l = *std::prev(it), r = *std::next(it);
    Calc(r - l - 1, 1);      // gaps (l, k) and (k, r) merge into (l, r)
    Get(std::prev(it), -1);
    Calc(k - l - 1, -1);
    Get(std::next(it), -1);
    Calc(r - k - 1, -1);
    Get(it, -1);
    const auto tl = std::prev(it), tr = std::next(it);
    st.erase(it);  // set::erase invalidates only `it`; tl and tr stay valid
    Get(tl, 1);
    Get(tr, 1);
}

// Reads n and a[1..n], sums x * m * (m - 1) + y over every window
// a[i - m + 1 .. i] for i = m..n, and prints the total.
// NOTE(review): the per-window formula comes from the problem's counting
// argument, which is not visible here.
void Solve() {
    std::cin >> n;
    m = n / 2;
    std::vector<int> a(n + 1);  // 1-indexed, sized to the input
    for (int i = 1; i <= n; ++i) std::cin >> a[i];

    // Reset shared state so Solve() can run once per test case.
    st.clear();
    st.insert(0);
    st.insert(n + 1);
    y = 0;
    // With no keys, [1, n] is a single gap holding n - m + 1 intervals
    // (that is, m + 1 when n is even).
    x = n - m + 1;

    for (int i = 1; i <= m; ++i) Add(a[i]);
    ll ans = 0;
    for (int i = m; i <= n; ++i) {
        ans += 1LL * x * m * (m - 1) + y;
        if (i != n) {
            Add(a[i + 1]);      // the window slides right ...
            Del(a[i - m + 1]);  // ... dropping its leftmost element
        }
    }
    std::cout << ans << '\n';
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int t = 1;
    // std::cin >> t;  // enable for multi-test input
    while (t--) Solve();
    return 0;
}

// Thank you very much!
- 1
信息
- ID
- 77
- 时间
- 3000ms
- 内存
- 512MiB
- 难度
- 6
- 标签
- 递交数
- 281
- 已通过
- 79
- 上传者