目录

AtCoder-Beginner-Contest-397-F题线段树

AtCoder Beginner Contest 397 F题(线段树)

题目链接

思路

L [ i ] L[i]

L

[

i

] 表示区间

[ 1 , i ] [1,i]

[

1

,

i

] 中有多少个不同的数。

R [ i ] R[i]

R

[

i

] 表示区间

[ i , n ] [i,n]

[

i

,

n

] 中有多少个不同的数。

n x t [ i ] nxt[i]

n

x

t

[

i

] 表示下一个与

a [ i ] a[i]

a

[

i

] 的值相同的数的下标。

假设我们现在已经枚举了第一个区间与第二个区间的分界点

i i

i ,

i i

i 表示第二个区间的起始端点。则第一个区间

[ 1 , i ] [1,i]

[

1

,

i

] 对答案产生的贡献为

L [ i ] L[i]

L

[

i

] 。

考虑第二个分界点

j j

j (

j j

j 表示第二个区间的终止端点)。

在区间

[ i , n ] [i,n]

[

i

,

n

] 中的数字,有两种情况:

  • 1,这个数字只出现了一次,无论放到第二个区间还是第三个区间,都只会产生

    1 1

    1 的贡献。

  • 2,这个数字出现了多次,那么如果

    j j

    j 选在了这个数字出现两次的中间,就会让

    [ i , j ] [i,j]

    [

    i

    ,

    j

    ] 和

    [ j + 1 , n ] [j+1,n]

    [

    j

1

,

n

] 都产生

1 1

1 的贡献。

所以,对于区间

[ i , n ] [i,n]

[

i

,

n

] 中的数字,至少产生

1 1

1 个贡献,还有可能多产生

1 1

1 个贡献。

所以,最终的答案就是

m a x ( L [ i − 1 ] + R [ i ] + 多产生的贡献 ) max(L[i-1]+R[i]+多产生的贡献)

ma

x

(

L

[

i

1

]

R

[

i

]

多产生的贡献

) 。

现在开始考虑如何计算贡献:

对于区间

[ i , n ] [i,n]

[

i

,

n

] 上的下标为

k k

k 的数字,如果其出现了多次,则下一次出现的下标为

n x t [ k ] nxt[k]

n

x

t

[

k

] 。如果第二个区间与第三个区间的分断点

j j

j 出现在

k k

k 和

n x t [ k ] nxt[k]

n

x

t

[

k

] 之间,则

a [ k ] a[k]

a

[

k

] 这个数字会额外产生

1 1

1 个贡献。

因此,对于额外产生的贡献,我们可以用线段树实现区间加法,并维护区间最大值。

时间复杂度:

O ( n l o g n ) O(nlogn)

O

(

n

l

o

g

n

) 。

代码

// #pragma GCC optimize("O2")
// #pragma GCC optimize("O3")
#include <bits/stdc++.h>

using namespace std;

#define int long long
#define double long double
#define endl "\n"

typedef long long i64;
typedef unsigned long long u64;
typedef pair<int, int> pii;

const int N = 3e5 + 5, M = 1e6 + 5;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f3f3f3f3f;
const double eps = 1e-6;

std::mt19937 rnd(time(0));

int n;
int a[N], L[N], R[N], nxt[N];
struct SegmentTree
{
    struct node
    {
        int l, r, maxx, tag;
    };
    vector<node>tree;

    SegmentTree() {}
    SegmentTree(int n) {tree.resize(n * 4 + 1);}

    void pushup(int u)
    {
        auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
        root.maxx = max(left.maxx , right.maxx);
    }

    void pushdown(int u)
    {
        auto &root = tree[u], &left = tree[u << 1], &right = tree[u << 1 | 1];
        if (root.tag)
        {
            left.tag += root.tag;
            right.tag += root.tag;
            left.maxx += root.tag;
            right.maxx += root.tag;
            root.tag = 0;
        }
    }

    void build(int u, int l, int r)
    {
        auto &root = tree[u];
        root = {l, r};
        if (l == r)
        {
            root.maxx = 0;
            return;
        }
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }

    void modify(int u, int l, int r, int val)
    {
        auto &root = tree[u];
        if (root.l >= l && root.r <= r)
        {
            root.maxx += val;
            root.tag += val;
            return;
        }
        pushdown(u);
        int mid = root.l + root.r >> 1;
        if (l <= mid) modify(u << 1, l, r, val);
        if (r > mid) modify(u << 1 | 1, l, r, val);
        pushup(u);
    }

    int query(int u, int l, int r)
    {
        auto &root = tree[u];
        if (root.l >= l && root.r <= r)
        {
            return root.maxx;
        }
        pushdown(u);
        int mid = root.l + root.r >> 1;
        int res = 0;
        if (l <= mid) res = query(u << 1, l, r);
        if (r > mid) res = max(res, query(u << 1 | 1, l, r));
        return res;
    }
};
void solve(int test_case)
{
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
    }
    set<int>st1, st2;
    for (int i = 1, j = n; i <= n; i++, j--)
    {
        st1.insert(a[i]), st2.insert(a[j]);
        L[i] = st1.size(), R[j] = st2.size();
    }
    vector<int>mp(n + 1, 0);
    for (int i = n; i >= 1; i--)
    {
        nxt[i] = mp[a[i]];
        mp[a[i]] = i;
    }
    SegmentTree smt(n);
    smt.build(1, 1, n);
    for (int i = n; i >= 2; i--)
    {
        if (nxt[i] > i)
        {
            smt.modify(1, i, nxt[i] - 1, 1);
        }
    }
    int ans = 0;
    for (int i = 2; i < n; i++)
    {
        int res = L[i - 1] + R[i];
        res += smt.query(1, i, n - 1);
        ans = max(ans, res);
        if (nxt[i] > i)
        {
            smt.modify(1, i, nxt[i] - 1, -1);
        }
    }
    cout << ans << endl;
}

signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    int test = 1;
    // cin >> test;
    for (int i = 1; i <= test ; i++)
    {
        solve(i);
    }
    return 0;
}