[GLUTEN-7750][VL]: store unsafe batches data #7902

Draft · wants to merge 2 commits into main
@@ -23,20 +23,20 @@ import org.apache.gluten.runtime.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.utils.ArrowAbiUtil
import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.utils.SparkArrowUtil
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.task.TaskResources

import org.apache.arrow.c.ArrowSchema
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
+import org.apache.spark.sql.types.DataTypes

import scala.collection.JavaConverters.asScalaIteratorConverter

-case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
+case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: UnsafeArray)
extends BuildSideRelation {

override def deserialized: Iterator[ColumnarBatch] = {
@@ -60,15 +60,16 @@ case class ColumnarBuildSideRelation(output: Seq[Attribute], batches: Array[Array[Byte]])
var batchId = 0

override def hasNext: Boolean = {
-batchId < batches.length
+batchId < batches.getLength
}

override def next: ColumnarBatch = {
-val handle =
-  jniWrapper
-    .deserialize(serializeHandle, batches(batchId))
+val batch = batches.get(batchId)
+val columnVector = new OffHeapColumnVector(batch.numElements(), DataTypes.BinaryType)
+columnVector.putByteArray(batchId, batch.toByteArray, batch.getBaseOffset.toInt, batch.numElements)
+val columnarBatch = new ColumnarBatch(Array(columnVector))
Comment on lines +68 to +70 (Contributor Author):
@zhztheplayer Is it alright to create the ColumnarBatch this way, using OffHeapColumnVector and constructing a new ColumnarBatch directly from that?

batchId += 1
-ColumnarBatches.create(handle)
+columnarBatch
}
})
.protectInvocationFlow()
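For context on the question above, a minimal, self-contained sketch of the pattern being asked about: a single BinaryType OffHeapColumnVector wrapped in a ColumnarBatch. The helper name toBinaryBatch and the one-payload-per-row layout are illustrative assumptions, not code from this PR; whether such a plain Spark batch is acceptable where the old code used ColumnarBatches.create is exactly the open question for the reviewer.

```scala
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

// Hypothetical helper: stores one serialized payload per row in a single
// BinaryType vector and exposes the result as a ColumnarBatch.
def toBinaryBatch(payloads: Array[Array[Byte]]): ColumnarBatch = {
  val numRows = payloads.length
  val vector = new OffHeapColumnVector(numRows, DataTypes.BinaryType)
  var i = 0
  while (i < numRows) {
    // putByteArray(rowId, value, offset, count) copies `count` bytes of `value`
    // starting at `offset` into row `rowId`.
    vector.putByteArray(i, payloads(i), 0, payloads(i).length)
    i += 1
  }
  val batch = new ColumnarBatch(Array[ColumnVector](vector))
  batch.setNumRows(numRows)
  batch
}
```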
@@ -0,0 +1,58 @@
package org.apache.spark.sql.execution

import org.apache.spark.memory.{MemoryConsumer, MemoryMode, SparkOutOfMemoryError, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.unsafe.{Platform, UnsafeAlignedOffset}
import org.apache.spark.unsafe.memory.MemoryBlock

/**
 * Append-only store for serialized byte payloads, kept in a single off-heap page
 * obtained from the task memory manager. Each record is laid out as
 * [total size][payload size][payload bytes][8-byte padding], and keyOffsets
 * remembers where each payload starts.
 */
class UnsafeArray(taskMemoryManager: TaskMemoryManager)
  extends MemoryConsumer(taskMemoryManager, MemoryMode.OFF_HEAP) {

  protected var page: MemoryBlock = null
  acquirePage(taskMemoryManager.pageSizeBytes)
  protected var base: AnyRef = page.getBaseObject
  protected var pageCursor: Long = 0
  // Offset (from `base`) of each stored payload; grown on demand in write().
  // The initial capacity of 16 is arbitrary.
  private var keyOffsets: Array[Long] = new Array[Long](16)
  protected var numRows = 0

  // Number of payloads stored so far; ColumnarBuildSideRelation uses this as getLength.
  def getLength: Int = numRows

  // View of all stored payloads, in insertion order.
  def iterator(): Iterator[UnsafeArrayData] = (0 until numRows).iterator.map(get)

  private def acquirePage(requiredSize: Long): Boolean = {
    try {
      page = allocatePage(requiredSize)
    } catch {
      // Give up if the memory manager cannot grant a page of the requested size.
      case _: SparkOutOfMemoryError =>
        return false
    }
    base = page.getBaseObject
    pageCursor = 0
    true
  }

  def get(rowId: Int): UnsafeArrayData = {
    // keyOffsets(rowId) points at the payload bytes; their length is stored
    // one uaoSize before that offset by write().
    val offset = keyOffsets(rowId)
    val klen = UnsafeAlignedOffset.getSize(base, offset - UnsafeAlignedOffset.getUaoSize)
    val result = new UnsafeArrayData
    result.pointTo(base, offset, klen)
    result
  }

  def write(bytes: Array[Byte], inputOffset: Long, inputLength: Int): Unit = {
    var offset: Long = page.getBaseOffset + pageCursor
    val recordOffset = offset

    val uaoSize = UnsafeAlignedOffset.getUaoSize

    // Record layout, similar to the one used by BytesToBytesMap:
    // [total size][payload size][payload bytes][8-byte padding].
    val recordLength = 2L * uaoSize + inputLength + 8L

    UnsafeAlignedOffset.putSize(base, offset, inputLength + uaoSize)
    UnsafeAlignedOffset.putSize(base, offset + uaoSize, inputLength)
    offset += 2L * uaoSize

Contributor Author: I'm not sure if I'm calculating the offset correctly

    Platform.copyMemory(bytes, inputOffset, base, offset, inputLength)
    Platform.putLong(base, offset + inputLength, 0)

    pageCursor += recordLength
    if (numRows == keyOffsets.length) {
      // Grow the offset table; the doubling policy is arbitrary.
      keyOffsets = java.util.Arrays.copyOf(keyOffsets, keyOffsets.length * 2)
    }
    keyOffsets(numRows) = recordOffset + 2L * uaoSize
    numRows += 1
  }

  // Spilling is not supported by this consumer yet; calling it throws NotImplementedError.
  override def spill(size: Long, trigger: MemoryConsumer): Long = ???
}
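Below is a minimal usage sketch of UnsafeArray as written above, assuming a caller-supplied TaskMemoryManager and a hypothetical roundTrip helper (neither is part of this PR). It illustrates the record layout the write method targets: [total size][payload size][payload bytes][8-byte padding], with keyOffsets pointing at the payload so that get(i) can re-point an UnsafeArrayData at the i-th record.

```scala
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.execution.UnsafeArray
import org.apache.spark.unsafe.Platform

// Hypothetical round-trip of one payload through UnsafeArray.
def roundTrip(tmm: TaskMemoryManager, payload: Array[Byte]): Array[Byte] = {
  val array = new UnsafeArray(tmm)
  // inputOffset is an absolute Platform offset into the byte[], hence BYTE_ARRAY_OFFSET.
  array.write(payload, Platform.BYTE_ARRAY_OFFSET, payload.length)

  // get(0) only re-points an UnsafeArrayData at the stored bytes; it is used here as a
  // (base, offset, length) view of the raw payload, not as UnsafeArrayData's own format.
  val view = array.get(0)
  val out = new Array[Byte](view.getSizeInBytes)
  Platform.copyMemory(view.getBaseObject, view.getBaseOffset, out, Platform.BYTE_ARRAY_OFFSET, view.getSizeInBytes)
  out
}
```

Compared with the previous batches: Array[Array[Byte]] field, this keeps the build-side payloads in a single task-managed off-heap page, which appears to be the goal of the change.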