Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ public void restore(Object state, BlockEncodingSerdeProvider serdeProvider)
this.segments = myState.segments;
}

// Diff hunk: each declaration appears twice, once as private and once as public
// (presumably the private lines are the removed side and the public lines the added side).
//
// Serializable holder for three values: the int[][] array, capacity and segments.
// The restore(...) fragment just above reads `segments` from a `myState` object,
// presumably an instance of this class.
private static class IntBigArrayState
public static class IntBigArrayState
implements Serializable
{
private int[][] array;
private int capacity;
private int segments;
public int[][] array;
public int capacity;
public int segments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ public void restore(Object state, BlockEncodingSerdeProvider serdeProvider)
this.segments = myState.segments;
}

// Diff hunk: each declaration appears twice, once as private and once as public
// (presumably the private lines are the removed side and the public lines the added side).
//
// Serializable holder for three values: the long[][] array, capacity and segments.
// The restore(...) fragment just above reads `segments` from a `myState` object,
// presumably an instance of this class.
private static class LongBigArrayState
public static class LongBigArrayState
implements Serializable
{
private long[][] array;
private int capacity;
private int segments;
public long[][] array;
public int capacity;
public int segments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ protected List<? extends OperatorFactory> createOperatorFactories()
10_000,
Optional.of(new DataSize(16, MEGABYTE)),
JOIN_COMPILER,
false);
false,
Optional.empty());

return ImmutableList.of(tableScanOperator, tpchQuery1Operator, aggregationOperator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ protected List<? extends OperatorFactory> createOperatorFactories()
100_000,
Optional.of(new DataSize(16, MEGABYTE)),
JOIN_COMPILER,
false);
false,
Optional.empty());
return ImmutableList.of(tableScanOperator, aggregationOperator);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import java.math.BigDecimal;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
Expand All @@ -62,6 +63,7 @@
import java.util.OptionalLong;
import java.util.Set;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
Expand Down Expand Up @@ -159,10 +161,10 @@ else if (sample != null) {
try {
Map<String, PartitionStatistics> statisticsSample = statisticsProvider.getPartitionsStatistics(session, schemaTableName, partitionsSample, table);
if (!includeColumnStatistics) {
OptionalDouble averageRows = calculateAverageRowsPerPartition(statisticsSample.values());
Optional<PartitionsRowCount> averageRows = calculatePartitionsRowCount(statisticsSample.values(), partitions.size());
TableStatistics.Builder result = TableStatistics.builder();
if (averageRows.isPresent()) {
result.setRowCount(Estimate.of(averageRows.getAsDouble() * partitions.size()));
result.setRowCount(Estimate.of(averageRows.get().getRowCount()));
}
result.setFileCount(calulateFileCount(statisticsSample.values()));
result.setOnDiskDataSizeInBytes(calculateTotalOnDiskSizeInBytes(statisticsSample.values()));
Expand Down Expand Up @@ -433,14 +435,12 @@ private static TableStatistics getTableStatistics(

checkArgument(!partitions.isEmpty(), "partitions is empty");

OptionalDouble optionalAverageRowsPerPartition = calculateAverageRowsPerPartition(statistics.values());
if (!optionalAverageRowsPerPartition.isPresent()) {
Optional<PartitionsRowCount> optionalRowCount = calculatePartitionsRowCount(statistics.values(), partitions.size());
if (!optionalRowCount.isPresent()) {
return TableStatistics.empty();
}
double averageRowsPerPartition = optionalAverageRowsPerPartition.getAsDouble();
verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
int queriedPartitionsCount = partitions.size();
double rowCount = averageRowsPerPartition * queriedPartitionsCount;

double rowCount = optionalRowCount.get().getRowCount();

TableStatistics.Builder result = TableStatistics.builder();
long fileCount = calulateFileCount(statistics.values());
Expand All @@ -457,6 +457,7 @@ private static TableStatistics getTableStatistics(
if (columnHandle.isPartitionKey()) {
tableColumnStatistics = statsCache.get(partitions.get(0).getTableName().getTableName() + columnName);
if (tableColumnStatistics == null || invalidateStatsCache(tableColumnStatistics, Estimate.of(rowCount), fileCount, totalOnDiskSize)) {
double averageRowsPerPartition = optionalRowCount.get().getAverageRowsPerPartition();
columnStatistics = createPartitionColumnStatistics(columnHandle, columnType, partitions, statistics, averageRowsPerPartition, rowCount);
TableStatistics tableStatistics = new TableStatistics(Estimate.of(rowCount), fileCount, totalOnDiskSize, ImmutableMap.of());
tableColumnStatistics = new TableColumnStatistics(tableStatistics, columnStatistics);
Expand Down Expand Up @@ -485,15 +486,44 @@ private static boolean invalidateStatsCache(TableColumnStatistics tableColumnSta
}

// Diff hunk: the lines of the old calculateAverageRowsPerPartition (signature,
// `return statistics.stream()`, `.average();`) are interleaved with the lines of the
// new calculatePartitionsRowCount.
//
// calculatePartitionsRowCount gathers the row counts that are present in the given
// partition statistics and produces a PartitionsRowCount holding
//   - an average number of rows per partition, and
//   - a total row count extrapolated to queriedPartitionsCount partitions.
// It yields Optional.empty() when none of the statistics carries a row count.
@VisibleForTesting
static OptionalDouble calculateAverageRowsPerPartition(Collection<PartitionStatistics> statistics)
static Optional<PartitionsRowCount> calculatePartitionsRowCount(Collection<PartitionStatistics> statistics, int queriedPartitionsCount)
{
return statistics.stream()
long[] rowCounts = statistics.stream()
.map(PartitionStatistics::getBasicStatistics)
.map(HiveBasicStatistics::getRowCount)
// Partitions without a row count are dropped here, so rowCounts may be shorter than statistics
.filter(OptionalLong::isPresent)
.mapToLong(OptionalLong::getAsLong)
// Every row count that is present is checked to be non-negative
.peek(count -> verify(count >= 0, "count must be greater than or equal to zero"))
.average();
.toArray();
// sampleSize counts every sampled partition, including those that had no row count
int sampleSize = statistics.size();
// Plain average, scaled by the queried partition count, is used when either
//  - at most two row counts are available (nothing would remain after removing min and max), or
//  - the sample size equals the number of queried partitions.
if (rowCounts.length <= 2 || queriedPartitionsCount == sampleSize) {
OptionalDouble averageRowsPerPartitionOptional = Arrays.stream(rowCounts).average();
// Empty only when rowCounts is empty, i.e. no partition reported a row count
if (!averageRowsPerPartitionOptional.isPresent()) {
return Optional.empty();
}
double averageRowsPerPartition = averageRowsPerPartitionOptional.getAsDouble();
return Optional.of(new PartitionsRowCount(averageRowsPerPartition, averageRowsPerPartition * queriedPartitionsCount));
}

// Some partitions (e.g. __HIVE_DEFAULT_PARTITION__) may be outliers in terms of row count.
// Excluding the min and max rowCount values from averageRowsPerPartition calculation helps to reduce the
// possibility of errors in the extrapolated rowCount due to a couple of outliers.
// Single pass over rowCounts: remember the positions of the smallest and largest values
// while accumulating the sum of all values.
int minIndex = 0;
int maxIndex = 0;
long rowCountSum = rowCounts[0];
for (int index = 1; index < rowCounts.length; index++) {
if (rowCounts[index] < rowCounts[minIndex]) {
minIndex = index;
}
else if (rowCounts[index] > rowCounts[maxIndex]) {
maxIndex = index;
}
rowCountSum += rowCounts[index];
}
// Average of the sample with one minimum and one maximum removed.
// rowCounts.length is at least 3 on this path, so the divisor is positive.
double averageWithoutOutliers = ((double) (rowCountSum - rowCounts[minIndex] - rowCounts[maxIndex])) / (rowCounts.length - 2);
// The trimmed average stands in for all queried partitions except two; the observed
// minimum and maximum are then added back once each.
// NOTE(review): assumes queriedPartitionsCount >= 2 here; a smaller value would make the
// factor negative - confirm that callers never pass fewer queried partitions than sampled ones.
double rowCount = (averageWithoutOutliers * (queriedPartitionsCount - 2)) + rowCounts[minIndex] + rowCounts[maxIndex];
return Optional.of(new PartitionsRowCount(averageWithoutOutliers, rowCount));
}

static long calulateFileCount(Collection<PartitionStatistics> statistics)
Expand Down Expand Up @@ -932,4 +962,58 @@ interface PartitionsStatisticsProvider
{
Map<String, PartitionStatistics> getPartitionsStatistics(ConnectorSession session, SchemaTableName schemaTableName, List<HivePartition> hivePartitions, Table table);
}

/**
 * Immutable value object pairing an average rows-per-partition figure with a
 * total row count. Both values must be greater than or equal to zero; the
 * constructor checks this with {@code verify}.
 *
 * <p>Instances compare by value: two instances are equal when both doubles
 * compare as equal under {@link Double#compare(double, double)}.
 */
@VisibleForTesting
static class PartitionsRowCount
{
    private final double averageRowsPerPartition;
    private final double rowCount;

    /**
     * @param averageRowsPerPartition average number of rows per partition, must be {@code >= 0}
     * @param rowCount total number of rows, must be {@code >= 0}
     */
    PartitionsRowCount(double averageRowsPerPartition, double rowCount)
    {
        // Check both arguments before any field is assigned
        verify(averageRowsPerPartition >= 0, "averageRowsPerPartition must be greater than or equal to zero");
        verify(rowCount >= 0, "rowCount must be greater than or equal to zero");
        this.rowCount = rowCount;
        this.averageRowsPerPartition = averageRowsPerPartition;
    }

    /** Returns the average number of rows per partition. */
    private double getAverageRowsPerPartition()
    {
        return this.averageRowsPerPartition;
    }

    /** Returns the total row count. */
    private double getRowCount()
    {
        return this.rowCount;
    }

    @Override
    public boolean equals(Object o)
    {
        if (o == this) {
            return true;
        }
        // Only instances of exactly this class can be equal
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        PartitionsRowCount other = (PartitionsRowCount) o;
        boolean sameAverage = Double.compare(other.averageRowsPerPartition, this.averageRowsPerPartition) == 0;
        boolean sameRowCount = Double.compare(other.rowCount, this.rowCount) == 0;
        return sameAverage && sameRowCount;
    }

    @Override
    public int hashCode()
    {
        return Objects.hash(this.averageRowsPerPartition, this.rowCount);
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("averageRowsPerPartition", this.averageRowsPerPartition)
                .add("rowCount", this.rowCount)
                .toString();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@
import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY;
import static io.prestosql.execution.SqlStageExecution.createSqlStageExecution;
import static io.prestosql.execution.scheduler.TestPhasedExecutionSchedule.createTableScanPlanFragment;
import static io.prestosql.execution.scheduler.TestSourcePartitionedScheduler.createFixedSplitSource;
import static io.prestosql.execution.scheduler.policy.TestPhasedExecutionSchedule.createTableScanPlanFragment;
import static io.prestosql.plugin.hive.HiveColumnHandle.BUCKET_COLUMN_NAME;
import static io.prestosql.plugin.hive.HiveColumnHandle.PATH_COLUMN_NAME;
import static io.prestosql.plugin.hive.HiveCompressionCodec.NONE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@
import static io.prestosql.plugin.hive.HiveType.HIVE_LONG;
import static io.prestosql.plugin.hive.HiveType.HIVE_STRING;
import static io.prestosql.plugin.hive.HiveUtil.parsePartitionValue;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateAverageRowsPerPartition;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.PartitionsRowCount;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSize;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDataSizeForPartitioningKey;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctPartitionKeys;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateDistinctValuesCount;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFraction;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateNullsFractionForPartitioningKey;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculatePartitionsRowCount;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRange;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.calculateRangeForPartitioningKey;
import static io.prestosql.plugin.hive.statistics.MetastoreHiveStatisticsProvider.convertPartitionValueToDouble;
Expand All @@ -82,6 +83,7 @@
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static java.lang.Double.NaN;
import static java.lang.String.format;
import static java.util.Collections.nCopies;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
Expand Down Expand Up @@ -240,15 +242,34 @@ public void testValidatePartitionStatistics()
}

// Diff hunk: the old test testCalculateAverageRowsPerPartition (header, opening brace and
// its seven assertions) is followed directly by the new test testCalculatePartitionsRowCount;
// only one closing brace is shown.
//
// The new test calls calculatePartitionsRowCount with a list of partition statistics and a
// queried-partition count, and compares the result with an expected
// PartitionsRowCount(averageRowsPerPartition, rowCount).
// rowsCount(n) is a helper not shown here - presumably it builds statistics reporting n rows.
@Test
public void testCalculateAverageRowsPerPartition()
{
assertThat(calculateAverageRowsPerPartition(ImmutableList.of())).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty()))).isEmpty();
assertThat(calculateAverageRowsPerPartition(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()))).isEmpty();
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10))), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), PartitionStatistics.empty())), OptionalDouble.of(10));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20))), OptionalDouble.of(15));
assertEquals(calculateAverageRowsPerPartition(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty())), OptionalDouble.of(15));
public void testCalculatePartitionsRowCount()
{
// No statistics, or only empty statistics: the result is expected to be empty
assertThat(calculatePartitionsRowCount(ImmutableList.of(), 0)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty()), 1)).isEmpty();
assertThat(calculatePartitionsRowCount(ImmutableList.of(PartitionStatistics.empty(), PartitionStatistics.empty()), 2)).isEmpty();
// One or two row counts: expected total is the plain average times the queried partition count
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 1))
.isEqualTo(Optional.of(new MetastoreHiveStatisticsProvider.PartitionsRowCount(10, 10)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), PartitionStatistics.empty()), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(10, 20)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 2))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 30)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(20), PartitionStatistics.empty()), 3))
.isEqualTo(Optional.of(new PartitionsRowCount(15, 45)));

// Three row counts with a queried count equal to the sample size: expected values are the
// plain average and the plain sum, with no outlier exclusion
assertThat(calculatePartitionsRowCount(ImmutableList.of(rowsCount(10), rowsCount(100), rowsCount(1000)), 3))
.isEqualTo(Optional.of(new PartitionsRowCount((10 + 100 + 1000) / 3.0, 10 + 100 + 1000)));
// Exclude outliers from average row count
// Sample of 12 (ten partitions with 100 rows, one with 1, one with 1000) for 50 queried partitions:
// expected average is 100, expected total is 100 * 48 plus the minimum and maximum once each
assertThat(calculatePartitionsRowCount(ImmutableList.<PartitionStatistics>builder()
.addAll(nCopies(10, rowsCount(100)))
.add(rowsCount(1))
.add(rowsCount(1000))
.build(),
50))
.isEqualTo(Optional.of(new PartitionsRowCount(100, (100 * 48) + 1 + 1000)));
}

@Test
Expand Down
Loading