输入样例:
3
3 2 1
输出样例:
9
样例解释
首先交换身高为3和2的小朋友,再交换身高为3和1的小朋友,最后交换身高为2和1的小朋友。每个小朋友都恰好被交换了2次,不高兴程度都是1+2=3,总和为9。
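解题思路:
只能交换相邻的两个小朋友,按冒泡排序的方式交换即可得到最少的交换次数。第 i 个小朋友必须和他前面所有比他高的人、以及他后面所有比他矮的人各交换恰好一次,因此他的交换次数 k_i = (前面比他高的人数) + (后面比他矮的人数)。他第 k 次被交换时不高兴程度增加 k,所以他最终的不高兴程度为 1 + 2 + … + k_i = k_i(k_i + 1)/2,答案为所有小朋友的这个值之和(结果可能超过 int 范围,需用 long)。两个计数都可以把身高作为下标,用树状数组或线段树分别从左到右、从右到左扫描一遍求出。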
Java代码:(树状数组)
import java.io.*;
public class Main {
    /** Exclusive upper bound of Fenwick-tree indices; shifted heights (height + 1) must stay below it. */
    static final int N = 1000005;
    /** Fenwick tree (1-based) counting how many inserted children have each shifted height. */
    static int []tr = new int[N];

    /**
     * Reads n followed by n heights from standard input and prints the total unhappiness.
     *
     * <p>Child i takes part in exactly (taller children before i) + (shorter children after i)
     * adjacent swaps, and k swaps cost 1 + 2 + ... + k = k(k + 1) / 2.
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // StreamTokenizer accepts any whitespace layout (leading blanks, repeated spaces,
        // heights spread over several lines), unlike readLine().split(" ").
        StreamTokenizer in = new StreamTokenizer(br);
        in.nextToken();
        int n = (int) in.nval;
        int[] w = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            in.nextToken();
            w[i] = (int) in.nval + 1; // shift by one so height 0 maps to index 1 (index 0 is unusable)
        }
        // swaps[i]: number of swaps child i takes part in
        int[] swaps = new int[n + 1];
        tr = new int[N]; // clear any state left over from an earlier call
        for (int i = 1; i <= n; i++) { // count taller children standing BEFORE child i
            swaps[i] = query(N - 1) - query(w[i]);
            add(w[i]);
        }
        tr = new int[N]; // reset before the right-to-left pass
        for (int i = n; i > 0; i--) { // count shorter children standing AFTER child i
            swaps[i] += query(w[i] - 1);
            add(w[i]);
        }
        long ans = 0;
        for (int i = 1; i <= n; i++) {
            ans += (long) swaps[i] * (swaps[i] + 1) / 2; // 1 + 2 + ... + swaps[i], in long
        }
        System.out.println(ans);
    }

    /** Returns the lowest set bit of x. */
    public static int lowbit(int x) {
        return x & -x;
    }

    /** Returns how many inserted children have a shifted height in [1, x]; x must be below N. */
    public static int query(int x) {
        int sum = 0;
        for (int i = x; i > 0; i -= lowbit(i)) sum += tr[i];
        return sum;
    }

    /** Records one child with shifted height x; x must be >= 1 (x = 0 would never advance). */
    public static void add(int x) {
        for (int i = x; i < N; i += lowbit(i)) tr[i]++;
    }
}
Java代码:(线段树)
import java.io.*;
public class Main {
    /** One past the largest segment-tree coordinate; shifted heights (height + 1) must stay below it. */
    static final int N = 1000005;
    /** Heap-indexed segment tree (root at 1, children of u at 2u and 2u + 1). */
    static Node []tree = new Node[4 * N];

    /** Segment-tree node covering coordinates [l, r]; sum counts inserted children in that range. */
    static class Node {
        int l, r;
        int sum;

        public Node(int l, int r, int sum) {
            this.l = l;
            this.r = r;
            this.sum = sum;
        }
    }

    /**
     * Reads n followed by n heights from standard input and prints the total unhappiness.
     *
     * <p>Child i takes part in exactly (taller children before i) + (shorter children after i)
     * adjacent swaps, and k swaps cost 1 + 2 + ... + k = k(k + 1) / 2.
     */
    public static void main(String[] args) throws NumberFormatException, IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        // StreamTokenizer accepts any whitespace layout, unlike readLine().split(" ").
        StreamTokenizer in = new StreamTokenizer(br);
        in.nextToken();
        int n = (int) in.nval;
        int[] w = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            in.nextToken();
            w[i] = (int) in.nval + 1; // shifted heights are >= 1, so coordinate 0 is never used
        }
        int[] swaps = new int[n + 1];
        build(1, 1, N - 1);
        for (int i = 1; i <= n; i++) { // taller children standing before child i
            swaps[i] = query(1, w[i] + 1, N - 1);
            modify(1, w[i], 1);
        }
        build(1, 1, N - 1); // reset: rebuild over the SAME range as the first pass
        for (int i = n; i > 0; i--) { // shorter children standing after child i
            swaps[i] += query(1, 1, w[i] - 1);
            modify(1, w[i], 1);
        }
        long ans = 0;
        for (int i = 1; i <= n; i++) {
            ans += (long) swaps[i] * (swaps[i] + 1) / 2; // 1 + 2 + ... + swaps[i], in long
        }
        System.out.println(ans);
    }

    /** Creates node u covering [l, r] with zero counts, recursively creating its subtree. */
    public static void build(int u, int l, int r) {
        tree[u] = new Node(l, r, 0);
        if (l >= r) return; // leaf (or empty range): no children
        int mid = (l + r) >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
    }

    /** Adds v to the count at coordinate x and refreshes the sums on the path back to node u. */
    public static void modify(int u, int x, int v) {
        Node node = tree[u];
        if (node.l == node.r) {
            node.sum += v;
            return;
        }
        int mid = (node.l + node.r) >> 1;
        if (x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        node.sum = tree[u << 1].sum + tree[u << 1 | 1].sum;
    }

    /** Returns the total count over coordinates [l, r] within node u's range; 0 when l > r. */
    public static int query(int u, int l, int r) {
        if (l > r) return 0;
        Node node = tree[u];
        if (l <= node.l && node.r <= r) return node.sum;
        int mid = (node.l + node.r) >> 1;
        int sum = 0;
        if (l <= mid) sum += query(u << 1, l, r);
        if (mid < r) sum += query(u << 1 | 1, l, r);
        return sum;
    }
}