题意描述
给定序列$S$,求出所有四元组$(a,b,c,d)$,满足$a < b$,$c < d$且$S_a < S_b$ ,$S_c > S_d$
要求$a , b , c , d$互不相等。
输入格式
第一行一个整数$n$,表示数组长度。
第二行$n$个整数,分别为$S_1 , S_2 , S_3……S_n$,表示$S$数组。
输出格式
输出仅有一行,包含一个整数,表示最终的答案。
Output ‘s eg
样例解释
合法的$(a , b , c , d)$如下:
$(a , b , c , d) ∈ \lbrace (1 , 2 , 3 , 4) , (1 , 3 , 2 , 4) , (1 , 4 , 2 , 3) \rbrace$
数据范围与约定
对于$20\%$的数据,保证$n < 200$
对于$50\%$的数据,保证$n < 1000$
对于另外$20\%$的数据,保证 $0 ≤ S_i ≤ 1$
对于$100\%$的数据,保证$n < 10^5 , 0 < S_i < 10^9$
分析
一道比较基础的组合计数题。
若题目中没有$a , b , c , d$互不相等,则答案就是逆序对与顺序对的乘积,直接树状数组解决即可。
考虑如何解决上面的限制。
不难发现,不合法的四元组只有四种情况,即$a = c$ , $a = d$ , $b = c$ , $b = d$。且不会出现三个元素相等。
则我们设$(i , S_i)$为每个点的坐标。若$a = c$,我们就只需要求出其右上方的点数乘以右下方的点数即可
剩下三种情况请各位读者自行画图推理。
最后,统计答案直接离散化后套树状数组。
剩下的见代码
Code[Accepted]
#include<iostream> #include<cstdio> #include<cstring> #include<string> #include<stack> #include<queue> #include<deque> #include<vector> #include<algorithm> #include<iomanip> #include<cstdlib> #include<cctype> #include<cmath>
#define ll long long #define I inline #define N 100001
using namespace std;
int n , m;
namespace Tree_array{ ll tree[N];
ll lowbit(ll x){ return x & (- x); }
I void add(ll x , ll k){ for(ll i = x; i <= n; i += lowbit(i)){ tree[i] += k; } }
I ll query(ll x){ ll ans = 0; for(ll i = x; i ; i -= lowbit(i)){ ans += tree[i]; } return ans; }
I void Clear(){ for(int i = 1; i <= n; i ++){ tree[i] = 0; } } }
using namespace Tree_array;
ll a[N] , b[N];
void input(){ cin >> n; for(int i = 1; i <= n; i ++){ cin >> a[i]; b[i] = a[i]; } sort(b + 1 , b + 1 + n); int number = unique(b + 1 , b + 1 + n) - b - 1; for(int i = 1; i <= n; i ++){ a[i] = lower_bound(b + 1 , b + 1 + number , a[i]) - b; } }
ll ls[N] , rs[N] , lb[N] , rb[N]; ll ans , p , q;
int main(){ freopen("a.in" , "r" , stdin); freopen("a.out" , "w" , stdout); input(); for(int i = n; i >= 1; i --){ rs[i] = query(a[i] - 1); p += rs[i]; rb[i] = query(n) - query(a[i]); add(a[i] , 1); } Clear(); for(int i = 1; i <= n; i ++){ ls[i] = query(a[i] - 1); q += ls[i]; lb[i] = query(n) - query(a[i]); add(a[i] , 1); } ans = p * q; for(int i = 1; i <= n; i ++){ ans -= rb[i] * rs[i] + rb[i] * lb[i] + rs[i] * ls[i] + lb[i] * ls[i]; } cout << ans << "\n"; return 0; }
|
THE END