2 条题解

  • 0
    @ 2025-5-19 21:16:45
    #include <bits/stdc++.h>
    using namespace std;
    
    namespace z {
    
    // NOTE: this macro turns every `int` below (array element types, loop
    // counters, locals) into `long long`. It is #undef'd before the namespace
    // closes, so the global `int main` further down is unaffected.
    #define int long long
    // Array capacity. Indices 0 .. n+1 are touched (e.g. smn[n+1], suf[n+1]),
    // so this assumes n + 1 < N — confirm against the problem constraints.
    const int N = 5e5 + 5;
    // a:        input values, 1-indexed.
    // pmx/pmn:  max / min of the prefix a[1..i] (pmn[0] is a large sentinel).
    // smx/smn:  max / min of the suffix a[i..n] (smn[n+1] is a large sentinel).
    // pre/suf:  sum of the prefix a[1..i] / of the suffix a[i..n].
    // ans:      accumulated answer, printed at the end.
    int n, a[N], pmx[N], pmn[N], smx[N], smn[N], ans, pre[N], suf[N];
    // pcx/pcn:  second-largest / second-smallest value of the prefix a[1..i].
    // scx/scn:  second-largest / second-smallest value of the suffix a[i..n].
    // With a single element they hold the scan's starting sentinel (0 for the
    // max scans, 2e9 for the min scans). This assumes values lie strictly
    // between 0 and 2e9 — TODO confirm.
    int pcx[N], pcn[N], scx[N], scn[N];
    // Reads n and a[1..n]. For every i in [0, n/2] it looks at the values that
    // remain after deleting the length-n/2 block a[i+1 .. i+n/2], i.e. the
    // prefix a[1..i] plus the suffix a[n/2+i+1 .. n]. It updates `ans` using
    // the min, max, second-min, second-max and sum of those remaining values,
    // then prints `ans`.
    // NOTE(review): the problem statement is not visible here. The checks read
    // as if a is a permutation of 1..n (see `minn - 1` and `maxn + 1 <= n`)
    // and n is even (so exactly n/2 values remain). Confirm both against the
    // problem constraints.
    void main() {
    
        ios::sync_with_stdio(false);
        cin.tie(nullptr);cout.tie(nullptr);
        cin >> n;
        // Large sentinels so pmn[0] and smn[n+1] never win a min().
        memset(pmn, 0x3f, sizeof(pmn));
        memset(smn, 0x3f, sizeof(smn));
        // Read the input while building prefix max / min / sum.
        for(int i = 1; i <= n; i++) {
            cin >> a[i];
            pmx[i] = max(pmx[i - 1], a[i]);
            pmn[i] = min(pmn[i - 1], a[i]);
            pre[i] = pre[i - 1] + a[i];
        }
        // Prefix second-largest. mx is the running max. When a new max
        // arrives, the old max becomes the second-largest; otherwise a[i] may
        // replace the previous second-largest.
        int mx = 0;
        for(int i = 1; i <= n; i++) {
            if(a[i] > mx) pcx[i] = mx, mx = a[i];
            else pcx[i] = a[i] > pcx[i - 1] ? a[i] : pcx[i - 1];
        }
        // Prefix second-smallest, mirror image of the loop above.
        int mn = 2e9;
        for(int i = 1; i <= n; i++) {
            if(a[i] < mn) pcn[i] = mn, mn = a[i];
            else pcn[i] = a[i] < pcn[i - 1] ? a[i] : pcn[i - 1];
        }
        // Suffix second-largest, scanning from the right.
        mx = 0;
        for(int i = n; i >= 1; i--) {
            if(a[i] > mx) scx[i] = mx, mx = a[i];
            else scx[i] = a[i] > scx[i + 1] ? a[i] : scx[i + 1];
        }
        // Suffix second-smallest, scanning from the right.
        mn = 2e9;
        for(int i = n; i >= 1; i--) {
            if(a[i] < mn) scn[i] = mn, mn = a[i];
            else scn[i] = a[i] < scn[i + 1] ? a[i] : scn[i + 1];
        }
        // Suffix max / min / sum.
        for(int i = n; i >= 1; i--) {
            smx[i] = max(smx[i + 1], a[i]);
            smn[i] = min(smn[i + 1], a[i]);
            suf[i] = suf[i + 1] + a[i];
        }
    
        // Statistics of the values remaining for the current i:
        // maxn/minn = max/min, cmax/cmin = second-largest/second-smallest,
        // sum = their total.
        int maxn, minn, cmax, cmin, sum;
        // i = number of kept prefix elements; the kept suffix starts at n/2+i+1.
        for(int i = 0; i <= n / 2; i++) {
            sum = pre[i] + suf[n / 2 + i + 1];
            if(i != 0 && i != n / 2) {
                // Both a non-empty prefix and a non-empty suffix remain. The
                // second-largest of the union is the largest of: the smaller of
                // the two maxima, and each side's own second-largest (and
                // symmetrically for the second-smallest).
                minn = min(pmn[i], smn[n / 2 + i + 1]);
                maxn = max(pmx[i], smx[n / 2 + i + 1]);
                cmax = max({min(pmx[i], smx[n / 2 + i + 1]), pcx[i], scx[n / 2 + i + 1]});
                cmin = min({max(pmn[i], smn[n / 2 + i + 1]), pcn[i], scn[n / 2 + i + 1]});
            } else {
                // Only the suffix (i == 0) or only the prefix (i == n/2) remains.
                minn = i ? pmn[i] : smn[n / 2 + i + 1];
                maxn = i ? pmx[i] : smx[n / 2 + i + 1];
                cmax = i ? pcx[i] : scx[n / 2 + i + 1];
                cmin = i ? pcn[i] : scn[n / 2 + i + 1];
            }
            // Is sum == minn + (minn+1) + ... + (minn + n/2 - 1)? For even n,
            // `* n / 4` equals `* (n/2) / 2`, the arithmetic-series sum of n/2
            // terms. For distinct values >= minn, equality means the remaining
            // values are exactly that consecutive run.
            if((minn + minn + n / 2 - 1) * n / 4 == sum) {
                ans += (n / 2 - 1) * n / 2;  
                if(minn - 1) ans++;          // +1 when minn != 1
                if(maxn + 1 <= n) ans++;     // +1 when maxn < n
                continue;
            } 
            // Not a consecutive run. Compare the second-largest (second-smallest)
            // remaining value with minn + n/2 - 1 or - 2 (maxn - n/2 + 1 or + 2).
            // The "- 2" / "+ 2" matches add one more when minn != 1 / maxn < n.
            // NOTE(review): the combinatorial meaning of these checks comes from
            // the problem statement, which is not visible here.
            if(cmax == minn + n / 2 - 1) ans++;
            if(cmax == minn + n / 2 - 2) ans++, ans += minn != 1;
            if(cmin == maxn - n / 2 + 1) ans++;
            if(cmin == maxn - n / 2 + 2) ans++, ans += maxn < n;
        }
        cout << ans << '\n';
    }
    
    // Restore the normal meaning of `int` for code outside this namespace.
    #undef int
    
    }
    
    
    // Program entry point: all of the work happens in z::main().
    int main() {
        z::main();
    }
    

    $O(n)$

    • 0
      @ 2024-10-25 21:55:43

      Answer is here!

      #include <cstdio>
      #include <algorithm>
      #include <iostream>
      #include <cmath>
      #include <cstring>
      #include <vector>
      #include <set>
      using namespace std;
      using ll = long long;
      // Capacity of a[]: values are read into a[1..n], so this assumes n < N.
      constexpr int N = 200010;
      // NOTE(review): not referenced anywhere in the visible solution.
      constexpr int MOD = 998244353;
      int n, m;        // input length and window length (Solve sets m = n / 2)
      int a[N];        // input sequence, 1-indexed
      int x, y;        // running counters maintained by Calc (x) and Get (y)
      ll ans;          // accumulated answer
      set<int> st;     // sentinels 0 and n + 1 plus the values inserted by Add
      // Count the placements that contain no key point (the author's term).
      // A gap of `len` free positions holds len - m + 1 windows of length m
      // whenever len >= m. That count is added to x, scaled by k (+1 to add the
      // gap, -1 to retract it).
      void Calc(int len, int k) {
          if(len < m) return;
          x += k * (len - m + 1);
      }
      // Count the placements that contain exactly one key point, namely *it.
      // These are the length-m windows that fit strictly between the
      // neighbours of *it in st and cover *it; the count is added to y, scaled
      // by k. Sentinel entries (values outside [1, n]) contribute nothing.
      void Get(set<int>::iterator it, int k) {
          const int p = *it;
          if(p < 1 || p > n) return;
          const int lo_nb = *prev(it);
          const int hi_nb = *next(it);
          if(hi_nb - lo_nb - 1 < m) return;
          // Window starts s must satisfy lo_nb < s, s + m - 1 < hi_nb, and
          // s <= p <= s + m - 1.
          const int first = max(lo_nb + 1, p - m + 1);
          const int last = min(p, hi_nb - m);
          y += k * (last - first + 1);
      }
      // Insert value k into st and update the counters x / y incrementally.
      // The gap (l, r) that k splits is retracted from x and replaced by the
      // two sub-gaps (l, k) and (k, r). The neighbours l and r have their
      // key-point contributions (Get) retracted, because their surrounding gaps
      // change. Once k is in the set, they are re-added together with k's own
      // contribution.
      // NOTE(review): presumably k is never already present (distinct values);
      // if it were, insert() leaves st unchanged, exactly as before.
      void Add(int k) {
          auto it = st.lower_bound(k);
          const int l = *prev(it), r = *it;
          Calc(r - l - 1, -1), Get(prev(it), -1);
          Calc(k - l - 1, 1), Get(it, -1);
          Calc(r - k - 1, 1);
          // insert() already returns the position of k, so a second O(log n)
          // lookup with find() is unnecessary.
          it = st.insert(k).first;
          Get(prev(it), 1);
          Get(next(it), 1);
          Get(it, 1);
      }
      void Del(int k) {
          auto it = st.find(k);
          int l = *prev(it), r = *next(it);
          Calc(r - l - 1, 1), Get(prev(it), -1);
          Calc(k - l - 1, -1), Get(next(it), -1);
          Calc(r - k - 1, -1), Get(it, -1);
          auto tl = prev(it), tr = next(it);
          st.erase(it);
          Get(tl, 1), Get(tr, 1);
      }
      // Read n and a[1..n], then slide a window of length m = n / 2 across the
      // sequence. At iteration r, st holds a[r-m+1 .. r]; each of the n - m + 1
      // windows adds x * m * (m - 1) + y to the answer, which is printed at the end.
      void Solve() {
          cin >> n;
          m = n / 2;
          for(int i = 1; i <= n; ++i) cin >> a[i];
          st.insert(0);       // left sentinel
          st.insert(n + 1);   // right sentinel
          // With only the sentinels present, the free gap has n positions.
          // NOTE(review): m + 1 equals that gap's window count n - m + 1 only
          // when n is even.
          x = m + 1;
          for(int i = 1; i <= m; ++i) Add(a[i]);
          for(int r = m; r <= n; ++r) {
              ans += 1ll * x * m * (m - 1) + y;
              if(r == n) break;
              // Slide the window one step to the right.
              Add(a[r + 1]);
              Del(a[r - m + 1]);
          }
          printf("%lld\n", ans);
      }
      // Entry point: untie/unsync the C++ streams for fast input, then solve a
      // single test case (the author left multi-test reading disabled).
      int main() {
          ios::sync_with_stdio(false);
          cin.tie(nullptr);
          for(int t = 1; t > 0; --t) Solve();
      }
      # Thank you very much!
      
      • 1

      信息

      ID
      77
      时间
      3000ms
      内存
      512MiB
      难度
      6
      标签
      递交数
      281
      已通过
      79
      上传者