diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index bed241b1e03a9..f053135c4dbd7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -90,9 +90,9 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
   private final int numElementsForSpillThreshold;

   /**
-   * Force this sorter to spill when the size in memory is beyond this threshold.
+   * Force this sorter to spill when the in-memory size in bytes exceeds this threshold.
    */
-  private final long recordsSizeForSpillThreshold;
+  private final long sizeInBytesForSpillThreshold;

   /** The buffer size to use when writing spills using DiskBlockObjectWriter */
   private final int fileBufferSizeBytes;
@@ -117,7 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
   @Nullable private ShuffleInMemorySorter inMemSorter;
   @Nullable private MemoryBlock currentPage = null;
   private long pageCursor = -1;
-  private long inMemRecordsSize = 0;
+  private long totalPageMemoryUsageBytes = 0;

   // Checksum calculator for each partition. Empty when shuffle checksum disabled.
   private final Checksum[] partitionChecksums;
@@ -142,7 +142,7 @@ final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleCheck
       (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.numElementsForSpillThreshold =
       (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD());
-    this.recordsSizeForSpillThreshold =
+    this.sizeInBytesForSpillThreshold =
       (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD());
     this.writeMetrics = writeMetrics;
     this.inMemSorter = new ShuffleInMemorySorter(
@@ -314,11 +314,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException {
   }

   private long getMemoryUsage() {
-    long totalPageSize = 0;
-    for (MemoryBlock page : allocatedPages) {
-      totalPageSize += page.size();
-    }
-    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
+    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageMemoryUsageBytes;
   }

   private void updatePeakMemoryUsed() {
@@ -342,11 +338,11 @@ private long freeMemory() {
     for (MemoryBlock block : allocatedPages) {
       memoryFreed += block.size();
       freePage(block);
+      totalPageMemoryUsageBytes -= block.size();
     }
     allocatedPages.clear();
     currentPage = null;
     pageCursor = 0;
-    inMemRecordsSize = 0;
     return memoryFreed;
   }

@@ -417,6 +413,7 @@ private void acquireNewPageIfNecessary(int required) {
       currentPage = allocatePage(required);
       pageCursor = currentPage.getBaseOffset();
       allocatedPages.add(currentPage);
+      totalPageMemoryUsageBytes += currentPage.size();
     }
   }

@@ -432,10 +429,17 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, numElementsForSpillThreshold));
       spill();
-    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
-      logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, recordsSizeForSpillThreshold));
+    }
+
+    // TODO: Ideally we only need to check the spill threshold when new memory needs to be
+    // allocated (both this sorter and the underlying ShuffleInMemorySorter may allocate
+    // new memory), but it's simpler to check the total memory usage of these two sorters
+    // before inserting each record.
+    final long usedMemory = getMemoryUsage();
+    if (usedMemory >= sizeInBytesForSpillThreshold) {
+      logger.info("Spilling data because memory usage ({}) crossed the threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, sizeInBytesForSpillThreshold));
       spill();
     }

@@ -453,7 +457,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, partitionId);
-    inMemRecordsSize += required;
   }

   /**
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 94c37e187131f..71a826642b1be 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -80,9 +80,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
   private final int numElementsForSpillThreshold;

   /**
-   * Force this sorter to spill when the size in memory is beyond this threshold.
+   * Force this sorter to spill when the in-memory size in bytes exceeds this threshold.
    */
-  private final long recordsSizeForSpillThreshold;
+  private final long sizeInBytesForSpillThreshold;

   /**
   * Memory pages that hold the records being sorted.
The pages in this list are freed when @@ -96,7 +96,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { // These variables are reset after spilling: @Nullable private volatile UnsafeInMemorySorter inMemSorter; - private long inMemRecordsSize = 0; + private long totalPageMemoryUsageBytes = 0; private MemoryBlock currentPage = null; private long pageCursor = -1; @@ -115,12 +115,12 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, - long recordsSizeForSpillThreshold, + long sizeInBytesForSpillThreshold, UnsafeInMemorySorter inMemorySorter, long existingMemoryConsumption) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, - pageSizeBytes, numElementsForSpillThreshold, recordsSizeForSpillThreshold, + pageSizeBytes, numElementsForSpillThreshold, sizeInBytesForSpillThreshold, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); taskContext.taskMetrics().incMemoryBytesSpilled(existingMemoryConsumption); @@ -140,11 +140,11 @@ public static UnsafeExternalSorter create( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, - long recordsSizeForSpillThreshold, + long sizeInBytesForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, - numElementsForSpillThreshold, recordsSizeForSpillThreshold, null, canUseRadixSort); + numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null, canUseRadixSort); } private UnsafeExternalSorter( @@ -157,7 +157,7 @@ private UnsafeExternalSorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, - long recordsSizeForSpillThreshold, + long sizeInBytesForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); @@ -187,7 +187,7 @@ private UnsafeExternalSorter( this.inMemSorter = existingInMemorySorter; } this.peakMemoryUsedBytes = getMemoryUsage(); - this.recordsSizeForSpillThreshold = recordsSizeForSpillThreshold; + this.sizeInBytesForSpillThreshold = sizeInBytesForSpillThreshold; this.numElementsForSpillThreshold = numElementsForSpillThreshold; // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at @@ -248,7 +248,6 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. inMemSorter.freeMemory(); - inMemRecordsSize = 0; // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the // records. Otherwise, if the task is over allocated memory, then without freeing the memory // pages, we might not be able to get memory for the pointer array. @@ -264,11 +263,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { * array. */ private long getMemoryUsage() { - long totalPageSize = 0; - for (MemoryBlock page : allocatedPages) { - totalPageSize += page.size(); - } - return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize; + return ((inMemSorter == null) ? 
0 : inMemSorter.getMemoryUsage()) + totalPageMemoryUsageBytes;
   }

   private void updatePeakMemoryUsed() {
@@ -320,6 +315,7 @@ private long freeMemory() {
     for (MemoryBlock block : pagesToFree) {
       memoryFreed += block.size();
       freePage(block);
+      totalPageMemoryUsageBytes -= block.size();
     }
     return memoryFreed;
   }
@@ -378,6 +374,7 @@ public void cleanupResources() {
       } finally {
         for (MemoryBlock pageToFree : pagesToFree) {
           freePage(pageToFree);
+          totalPageMemoryUsageBytes -= pageToFree.size();
         }
         if (inMemSorterToFree != null) {
           inMemSorterToFree.freeMemory();
@@ -448,6 +445,7 @@ private void acquireNewPageIfNecessary(int required) {
       currentPage = allocatePage(required);
       pageCursor = currentPage.getBaseOffset();
       allocatedPages.add(currentPage);
+      totalPageMemoryUsageBytes += currentPage.size();
     }
   }

@@ -495,10 +493,17 @@ public void insertRecord(
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_RECORDS, inMemSorter.numRecords()),
         MDC.of(LogKeys.NUM_ELEMENTS_SPILL_THRESHOLD, numElementsForSpillThreshold));
       spill();
-    } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) {
-      logger.info("Spilling data because size of spilledRecords ({}) crossed the size threshold {}",
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE, inMemRecordsSize),
-        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, recordsSizeForSpillThreshold));
+    }
+
+    // TODO: Ideally we only need to check the spill threshold when new memory needs to be
+    // allocated (both this sorter and the underlying UnsafeInMemorySorter may allocate
+    // new memory), but it's simpler to check the total memory usage of these two sorters
+    // before inserting each record.
+    final long usedMemory = getMemoryUsage();
+    if (usedMemory >= sizeInBytesForSpillThreshold) {
+      logger.info("Spilling data because memory usage ({}) crossed the threshold {}",
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE, usedMemory),
+        MDC.of(LogKeys.SPILL_RECORDS_SIZE_THRESHOLD, sizeInBytesForSpillThreshold));
       spill();
     }

@@ -514,7 +519,6 @@ public void insertRecord(
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
-    inMemRecordsSize += required;
   }

   /**
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0bee708bca3c7..120948064f92b 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1599,9 +1599,9 @@ package object config {
     .createWithDefault(Integer.MAX_VALUE)

   private[spark] val SHUFFLE_SPILL_MAX_SIZE_FORCE_SPILL_THRESHOLD =
-    ConfigBuilder("spark.shuffle.spill.maxRecordsSizeForSpillThreshold")
+    ConfigBuilder("spark.shuffle.spill.maxSizeInBytesForSpillThreshold")
      .internal()
-      .doc("The maximum size in memory before forcing the shuffle sorter to spill. " +
+      .doc("The maximum in-memory size in bytes before forcing the shuffle sorter to spill. 
" + "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " + "until we reach some limitations, like the max page size limitation for the pointer " + "array in the sorter.") diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 6affcb61b8d69..ca49c5f306ca6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -61,9 +61,9 @@ public UnsafeKVExternalSorter( SerializerManager serializerManager, long pageSizeBytes, int numElementsForSpillThreshold, - long maxRecordsSizeForSpillThreshold) throws IOException { + long sizeInBytesForSpillThreshold) throws IOException { this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, - numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null); + numElementsForSpillThreshold, sizeInBytesForSpillThreshold, null); } public UnsafeKVExternalSorter( @@ -73,7 +73,7 @@ public UnsafeKVExternalSorter( SerializerManager serializerManager, long pageSizeBytes, int numElementsForSpillThreshold, - long maxRecordsSizeForSpillThreshold, + long sizeInBytesForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -100,7 +100,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, - maxRecordsSizeForSpillThreshold, + sizeInBytesForSpillThreshold, canUseRadixSort); } else { // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow @@ -168,7 +168,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, - maxRecordsSizeForSpillThreshold, + sizeInBytesForSpillThreshold, inMemSorter, map.getTotalMemoryConsumption()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index 3e98c28b29fbc..e8d5858b04fed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.internal.LogKeys.{CLASS_NAME, MAX_NUM_ROWS_IN_MEMORY_BUFFER} +import org.apache.spark.internal.LogKeys.{CLASS_NAME, MAX_NUM_ROWS_IN_MEMORY_BUFFER, NUM_BYTES_MAX} import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -34,17 +34,18 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array - * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which - * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is - * excessive memory consumption). 
Setting these threshold involves following trade-offs:
+ * until [[numRowsInMemoryBufferThreshold]] or [[sizeInBytesInMemoryBufferThreshold]] is reached,
+ * after which it will switch to a mode (backed by [[UnsafeExternalSorter]]) which would flush to
+ * disk after [[numRowsSpillThreshold]] or [[sizeInBytesSpillThreshold]] is met (or before if there
+ * is excessive memory consumption). Setting these thresholds involves the following trade-offs:
  *
- * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory
- *   than is available, resulting in OOM.
- * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
- *   excessive disk writes. This may lead to a performance regression compared to the normal case
- *   of using an [[ArrayBuffer]] or [[Array]].
+ * - If [[numRowsInMemoryBufferThreshold]] and [[sizeInBytesInMemoryBufferThreshold]] are too high,
+ *   the in-memory array may occupy more memory than is available, resulting in OOM.
+ * - If [[numRowsSpillThreshold]] or [[sizeInBytesSpillThreshold]] is too low, data will be spilled
+ *   frequently and lead to excessive disk writes. This may lead to a performance regression
+ *   compared to the normal case of using an [[ArrayBuffer]] or [[Array]].
  */
-private[sql] class ExternalAppendOnlyUnsafeRowArray(
+class ExternalAppendOnlyUnsafeRowArray(
     taskMemoryManager: TaskMemoryManager,
     blockManager: BlockManager,
     serializerManager: SerializerManager,
@@ -52,12 +53,15 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
     initialSize: Int,
     pageSizeBytes: Long,
     numRowsInMemoryBufferThreshold: Int,
+    sizeInBytesInMemoryBufferThreshold: Long,
     numRowsSpillThreshold: Int,
-    maxSizeSpillThreshold: Long) extends Logging {
+    sizeInBytesSpillThreshold: Long) extends Logging {

-  def this(numRowsInMemoryBufferThreshold: Int,
-      numRowsSpillThreshold: Int,
-      maxSizeSpillThreshold: Long) = {
+  def this(
+      numRowsInMemoryBufferThreshold: Int,
+      sizeInBytesInMemoryBufferThreshold: Long,
+      numRowsSpillThreshold: Int,
+      sizeInBytesSpillThreshold: Long) = {
     this(
       TaskContext.get().taskMemoryManager(),
       SparkEnv.get.blockManager,
@@ -66,8 +70,9 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
       1024,
       SparkEnv.get.memoryManager.pageSizeBytes,
       numRowsInMemoryBufferThreshold,
+      sizeInBytesInMemoryBufferThreshold,
       numRowsSpillThreshold,
-      maxSizeSpillThreshold)
+      sizeInBytesSpillThreshold)
   }

   private val initialSizeOfInMemoryBuffer =
@@ -78,6 +83,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   } else {
     null
   }
+  private var inMemoryBufferSizeInBytes = 0L

   private var spillableArray: UnsafeExternalSorter = _
   private var totalSpillBytes: Long = 0
@@ -116,6 +122,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
       spillableArray = null
     } else if (inMemoryBuffer != null) {
       inMemoryBuffer.clear()
+      inMemoryBufferSizeInBytes = 0
     }
     numFieldsPerRow = 0
     numRows = 0
@@ -123,12 +130,16 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray(
   }

   def add(unsafeRow: UnsafeRow): Unit = {
-    if (numRows < numRowsInMemoryBufferThreshold) {
+    // Once we spill, we switch to UnsafeExternalSorter permanently.
+ if (spillableArray == null && numRows < numRowsInMemoryBufferThreshold && + inMemoryBufferSizeInBytes < sizeInBytesInMemoryBufferThreshold) { inMemoryBuffer += unsafeRow.copy() + inMemoryBufferSizeInBytes += unsafeRow.getSizeInBytes } else { if (spillableArray == null) { logInfo(log"Reached spill threshold of " + log"${MDC(MAX_NUM_ROWS_IN_MEMORY_BUFFER, numRowsInMemoryBufferThreshold)} rows, " + + log"or ${MDC(NUM_BYTES_MAX, sizeInBytesInMemoryBufferThreshold)} bytes, " + log"switching to ${MDC(CLASS_NAME, classOf[UnsafeExternalSorter].getName)}") // We will not sort the rows, so prefixComparator and recordComparator are null @@ -142,7 +153,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize, pageSizeBytes, numRowsSpillThreshold, - maxSizeSpillThreshold, + sizeInBytesSpillThreshold, false) // populate with existing in-memory buffered rows @@ -156,6 +167,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( false) ) inMemoryBuffer.clear() + inMemoryBufferSizeInBytes = 0 } numFieldsPerRow = unsafeRow.numFields() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala index 64bb3717f52bc..cf146889912d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala @@ -44,7 +44,7 @@ class UpdatingSessionsIterator( inputSchema: Seq[Attribute], inMemoryThreshold: Int, spillThreshold: Int, - spillSizeThreshold: Long) extends Iterator[InternalRow] { + sizeInBytesSpillThreshold: Long) extends Iterator[InternalRow] { private val groupingWithoutSession: Seq[NamedExpression] = groupingExpressions.diff(Seq(sessionExpression)) @@ -151,8 +151,13 @@ class UpdatingSessionsIterator( currentKeys = groupingKey.copy() currentSession = sessionStruct.copy() - rowsForCurrentSession = new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, - spillSizeThreshold) + rowsForCurrentSession = new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? + sizeInBytesSpillThreshold, + spillThreshold, + sizeInBytesSpillThreshold) rowsForCurrentSession.add(currentRow.asInstanceOf[UnsafeRow]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8065decb0dffe..a8523746f9d2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -37,12 +37,17 @@ class UnsafeCartesianRDD( right : RDD[UnsafeRow], inMemoryBufferThreshold: Int, spillThreshold: Int, - spillSizeThreshold: Long) + sizeInBytesSpillThreshold: Long) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold, - spillSizeThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray( + inMemoryBufferThreshold, + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? 
+ sizeInBytesSpillThreshold, + spillThreshold, + sizeInBytesSpillThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala index b4e52ba050b8d..2b6a19dfa8a8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala @@ -34,7 +34,7 @@ class SortMergeJoinEvaluatorFactory( output: Seq[Attribute], inMemoryThreshold: Int, spillThreshold: Int, - spillSizeThreshold: Long, + sizeInBytesSpillThreshold: Long, numOutputRows: SQLMetric, spillSize: SQLMetric, onlyBufferFirstMatchedRow: Boolean) @@ -86,7 +86,7 @@ class SortMergeJoinEvaluatorFactory( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources) private[this] val joinRow = new JoinedRow @@ -132,7 +132,7 @@ class SortMergeJoinEvaluatorFactory( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources) val rightNullRow = new GenericInternalRow(right.output.length) @@ -152,7 +152,7 @@ class SortMergeJoinEvaluatorFactory( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources) val leftNullRow = new GenericInternalRow(left.output.length) @@ -189,7 +189,7 @@ class SortMergeJoinEvaluatorFactory( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources, onlyBufferFirstMatchedRow) @@ -227,7 +227,7 @@ class SortMergeJoinEvaluatorFactory( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources, onlyBufferFirstMatchedRow) @@ -272,7 +272,7 @@ class SortMergeJoinEvaluatorFactory( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, spillSize, cleanupResources, onlyBufferFirstMatchedRow) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 39387ebbb7ee3..bc2f9197df9df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -103,7 +103,7 @@ case class SortMergeJoinExec( conf.sortMergeJoinExecBufferSpillThreshold } - private def getSpillSizeThreshold: Long = { + private def getSizeInBytesSpillThreshold: Long = { conf.sortMergeJoinExecBufferSpillSizeThreshold } @@ -125,7 +125,7 @@ case class SortMergeJoinExec( val numOutputRows = longMetric("numOutputRows") val spillSize = longMetric("spillSize") val spillThreshold = getSpillThreshold - val spillSizeThreshold = getSpillSizeThreshold + val sizeInBytesSpillThreshold = getSizeInBytesSpillThreshold val inMemoryThreshold = getInMemoryThreshold val evaluatorFactory = new SortMergeJoinEvaluatorFactory( leftKeys, @@ -137,7 +137,7 @@ case class SortMergeJoinExec( output, 
inMemoryThreshold, spillThreshold, - spillSizeThreshold, + sizeInBytesSpillThreshold, numOutputRows, spillSize, onlyBufferFirstMatchedRow @@ -228,12 +228,15 @@ case class SortMergeJoinExec( val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold - val spillSizeThreshold = getSpillSizeThreshold + val sizeInBytesSpillThreshold = getSizeInBytesSpillThreshold val inMemoryThreshold = getInMemoryThreshold // Inline mutable state since not many join operations in a task val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold, ${spillSizeThreshold}L);", + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? + v => s"$v = new $clsName($inMemoryThreshold, ${sizeInBytesSpillThreshold}L, " + + s"$spillThreshold, ${sizeInBytesSpillThreshold}L);", forceInline = true) // Copy the streamed keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, streamedKeyVars) @@ -1052,7 +1055,8 @@ case class SortMergeJoinExec( * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by * internal buffer * @param spillThreshold Threshold for number of rows to be spilled by internal buffer - * @param spillSizeThreshold Threshold for size of rows to be spilled by internal buffer + * @param sizeInBytesSpillThreshold Threshold for size in bytes of rows to be spilled by + * internal buffer * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the first matching row */ @@ -1064,7 +1068,7 @@ private[joins] class SortMergeJoinScanner( bufferedIter: RowIterator, inMemoryThreshold: Int, spillThreshold: Int, - spillSizeThreshold: Long, + sizeInBytesSpillThreshold: Long, spillSize: SQLMetric, eagerCleanupResources: () => Unit, onlyBufferFirstMatch: Boolean = false) { @@ -1079,7 +1083,13 @@ private[joins] class SortMergeJoinScanner( private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, spillSizeThreshold) + new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? + sizeInBytesSpillThreshold, + spillThreshold, + sizeInBytesSpillThreshold) // At the end of the task, update the task's spill size for buffered side. 
TaskContext.get().addTaskCompletionListener[Unit](_ => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala index d6cc350e485a8..a92679054dd81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowAggregatePythonExec.scala @@ -215,10 +215,10 @@ case class ArrowAggregatePythonExec( case Some(sessionExpression) => val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold val spillThreshold = conf.windowExecBufferSpillThreshold - val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold + val sizeInBytesSpillThreshold = conf.windowExecBufferSpillSizeThreshold new UpdatingSessionsIterator(iter, groupingWithoutSessionExpressions, sessionExpression, - child.output, inMemoryThreshold, spillThreshold, spillSizeThreshold) + child.output, inMemoryThreshold, spillThreshold, sizeInBytesSpillThreshold) case None => iter } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala index 92ed9ff9de456..1643a8d3bdb1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowWindowPythonEvaluatorFactory.scala @@ -149,7 +149,7 @@ class ArrowWindowPythonEvaluatorFactory( private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold private val spillThreshold = conf.windowExecBufferSpillThreshold - private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold + private val sizeInBytesSpillThreshold = conf.windowExecBufferSpillSizeThreshold private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val largeVarTypes = conf.arrowUseLargeVarTypes @@ -288,8 +288,13 @@ class ArrowWindowPythonEvaluatorFactory( // Manage the current partition. val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, - spillSizeThreshold) + new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? 
+ sizeInBytesSpillThreshold, + spillThreshold, + sizeInBytesSpillThreshold) var bufferIterator: Iterator[UnsafeRow] = _ val indexRow = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala index d59a0e9f4639b..c4b20d4b7c7d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowEvaluatorFactory.scala @@ -45,7 +45,7 @@ class WindowEvaluatorFactory( private val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray private val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold private val spillThreshold = conf.windowExecBufferSpillThreshold - private val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold + private val sizeInBytesSpillThreshold = conf.windowExecBufferSpillSizeThreshold override def eval( partitionIndex: Int, @@ -83,8 +83,13 @@ class WindowEvaluatorFactory( // Manage the current partition. val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, - spillSizeThreshold) + new ExternalAppendOnlyUnsafeRowArray( + inMemoryThreshold, + // TODO: shall we have a new config to specify the max in-memory buffer size + // of ExternalAppendOnlyUnsafeRowArray? + sizeInBytesSpillThreshold, + spillThreshold, + sizeInBytesSpillThreshold) var bufferIterator: Iterator[UnsafeRow] = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 461c899325f44..124e15397ca52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -107,6 +107,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { for (_ <- 0L until iterations) { val array = new ExternalAppendOnlyUnsafeRowArray( ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + Long.MaxValue, numSpillThreshold, Long.MaxValue) @@ -172,7 +173,9 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, + val array = new ExternalAppendOnlyUnsafeRowArray( + numSpillThreshold, + Long.MaxValue, numSpillThreshold, Long.MaxValue) rows.foreach(x => array.add(x)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index 62ea7f2f92597..e667a95269f42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -47,6 +47,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar 1024, SparkEnv.get.memoryManager.pageSizeBytes, inMemoryThreshold, + Long.MaxValue, spillThreshold, Long.MaxValue) try f(array) finally {
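Not part of the patch itself: a minimal usage sketch of the renamed spill config and the widened ExternalAppendOnlyUnsafeRowArray constructor introduced above. The threshold values are illustrative only, and the array can only be built inside a running task, since its short constructor pulls TaskContext and SparkEnv from the environment.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

// Illustrative value only: ask the shuffle/unsafe sorters to force a spill once their pages plus
// the in-memory sorter reach 256 MiB, instead of relying solely on the element-count threshold.
// (The key is an internal config; the default Long.MAX_VALUE means "never force by size".)
val conf = new SparkConf()
  .set("spark.shuffle.spill.maxSizeInBytesForSpillThreshold", (256L * 1024 * 1024).toString)

// Argument order of the new four-argument constructor, mirroring the updated benchmark/suite.
// Must be called from code running inside a task (it needs TaskContext.get() and SparkEnv.get).
def makeBuffer(): ExternalAppendOnlyUnsafeRowArray =
  new ExternalAppendOnlyUnsafeRowArray(
    4096,               // numRowsInMemoryBufferThreshold: rows kept in the plain in-memory buffer
    64L * 1024 * 1024,  // sizeInBytesInMemoryBufferThreshold: byte cap for that buffer
    Int.MaxValue,       // numRowsSpillThreshold: row-count spill threshold of the backing sorter
    Long.MaxValue)      // sizeInBytesSpillThreshold: size-in-bytes spill threshold of the sorter
```

As the TODOs in the diff note, the SQL callers currently pass sizeInBytesSpillThreshold for the in-memory buffer cap as well, pending a dedicated config for that limit.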