/* package whatever; // don't place package name! */
import java.util.*;
import java.lang.*;
import java.io.*;
/* Name of the class has to be "Main" only if the class is public. */
public class Main {

    /** Demo entry point: prints the total imbalance of a small sample list. */
    public static void main(String[] args) {
        List<Integer> nums = Arrays.asList(4, 1, 3, 2);
        System.out.println(findTotalImbalance(nums));
    }

    /**
     * Sums the imbalance of every contiguous subarray of {@code rank}.
     *
     * <p>The imbalance of a subarray is the number of adjacent pairs in its sorted order whose
     * difference is greater than 1. For each start index the subarray is grown one element at a
     * time, and its imbalance is updated incrementally from the new element's sorted neighbours
     * (found via a {@link TreeSet}), giving O(n^2 log n) time overall.
     *
     * @param rank the values to examine; duplicates are allowed
     * @return the sum of imbalances over all subarrays (0 for empty or single-element input)
     * @throws NullPointerException if {@code rank} or any of its elements is {@code null}
     */
    public static long findTotalImbalance(List<Integer> rank) {
        Objects.requireNonNull(rank, "rank");

        // Copy once: avoids repeated unboxing and O(n) get() calls on non-random-access lists.
        int[] values = new int[rank.size()];
        int k = 0;
        for (int v : rank) {
            values[k++] = v;
        }

        int n = values.length;
        long totalImbalance = 0;
        TreeSet<Integer> window = new TreeSet<>();
        // A single-element subarray has imbalance 0, so the last start index contributes nothing.
        for (int start = 0; start < n - 1; start++) {
            window.clear();
            window.add(values[start]);
            long currentImbalance = 0;
            for (int end = start + 1; end < n; end++) {
                int current = values[end];
                // A duplicate lands next to an equal value in sorted order (gap 0), so it leaves
                // the imbalance unchanged; only a newly inserted value can alter it.
                if (window.add(current)) {
                    currentImbalance +=
                            imbalanceDelta(window.lower(current), current, window.higher(current));
                }
                totalImbalance += currentImbalance;
            }
        }
        return totalImbalance;
    }

    /**
     * Returns the change in imbalance caused by inserting {@code value} between its sorted
     * neighbours {@code lower} and {@code higher} (either may be {@code null} if absent).
     */
    private static int imbalanceDelta(Integer lower, int value, Integer higher) {
        int delta = 0;
        if (lower != null && higher != null && hasGap(lower, higher)) {
            delta--; // the existing gap lower..higher is split by the new value
        }
        if (lower != null && hasGap(lower, value)) {
            delta++;
        }
        if (higher != null && hasGap(value, higher)) {
            delta++;
        }
        return delta;
    }

    /** Returns true if {@code hi - lo > 1}, computed in long so extreme ints cannot overflow. */
    private static boolean hasGap(int lo, int hi) {
        return (long) hi - lo > 1;
    }
}