SQL Server supports the FOR XML PATH() syntax, which wraps a query's results in the specified XML tags. Our team had implemented this feature on Spark 2.0. When we migrated to 2.3, the extensive changes in stock Spark caused severe compatibility problems, and my job was to get this function working again. As a rookie I was thoroughly put through the wringer, but fortunately the problem did get solved.
1. Syntax overview
In SQL Server, FOR XML PATH essentially wraps the specified columns and concatenated strings in XML tags and returns the result, with support for custom node names. Our team's Spark implementation is called group_xmlpath(); for now it does not support custom tag names.
@ExpressionDescription(
  usage = "_FUNC_(expr) - Concat a list of elements in a group.")
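For orientation, here is a hypothetical call (the table, columns, and data are made up; judging from initializeColNames in the implementation shown later, an optional trailing array literal supplies the root tag, the row tag, and then one tag per column):

spark.sql(
  """SELECT dept, group_xmlpath(name, array('root', 'row', 'name'))
    |FROM employees
    |GROUP BY dept""".stripMargin).show(false)
// Each group produces one cell shaped roughly like:
// <root><row><name>Alice</name></row><row><name>Bob</name></row></root>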
2. Function implementation
Under the hood this function is an aggregate, so it is implemented in Collect.scala, modeled after CollectList. Spark 2.3 refactored the upper-layer interface, introducing TypedImperativeAggregate and declaring many methods final, which made all of our previous custom code unusable.
So, to make a custom implementation possible, we opened up those final methods so they can be overridden, as sketched below.
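The kind of change involved looks roughly like this (a simplified sketch of the fork, not the exact diff; the method body matches what Spark 2.3's TypedImperativeAggregate does):

// Spark 2.3 upstream (simplified): the template method is final.
final override def initialize(buffer: InternalRow): Unit = {
  buffer(mutableAggBufferOffset) = createAggregationBuffer()
}

// Our fork: same body, with `final` dropped so subclasses can override it.
override def initialize(buffer: InternalRow): Unit = {
  buffer(mutableAggBufferOffset) = createAggregationBuffer()
}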
TypedImperativeAggregate defines the aggregation workflow in roughly three stages: initialization, processing, and producing the result, corresponding to the initialize, update/merge, and eval methods. The interface's own documentation spells out the flow:
/**
 * General work flow:
*
* Stage 1: initialize aggregate buffer object.
*
* 1. The framework calls `initialize(buffer: MutableRow)` to set up the empty aggregate buffer.
* 2. In `initialize`, we call `createAggregationBuffer(): T` to get the initial buffer object,
* and set it to the global buffer row.
*
*
* Stage 2: process input rows.
*
* If the aggregate mode is `Partial` or `Complete`:
* 1. The framework calls `update(buffer: MutableRow, input: InternalRow)` to process the input
* row.
* 2. In `update`, we get the buffer object from the global buffer row and call
* `update(buffer: T, input: InternalRow): Unit`.
*
* If the aggregate mode is `PartialMerge` or `Final`:
* 1. The framework call `merge(buffer: MutableRow, inputBuffer: InternalRow)` to process the
* input row, which are serialized buffer objects shuffled from other nodes.
* 2. In `merge`, we get the buffer object from the global buffer row, and get the binary data
* from input row and deserialize it to buffer object, then we call
* `merge(buffer: T, input: T): Unit` to merge these 2 buffer objects.
*
*
* Stage 3: output results.
*
* If the aggregate mode is `Partial` or `PartialMerge`:
* 1. The framework calls `serializeAggregateBufferInPlace` to replace the buffer object in the
* global buffer row with binary data.
* 2. In `serializeAggregateBufferInPlace`, we get the buffer object from the global buffer row
* and call `serialize(buffer: T): Array[Byte]` to serialize the buffer object to binary.
* 3. The framework outputs buffer attributes and shuffle them to other nodes.
*
* If the aggregate mode is `Final` or `Complete`:
* 1. The framework calls `eval(buffer: InternalRow)` to calculate the final result.
* 2. In `eval`, we get the buffer object from the global buffer row and call
* `eval(buffer: T): Any` to get the final result.
* 3. The framework outputs these final results.
*
*
* Window function work flow:
* The framework calls `update(buffer: MutableRow, input: InternalRow)` several times and then
* call `eval(buffer: InternalRow)`, so there is no need for window operator to call
* `serializeAggregateBufferInPlace`.
*
*
* NOTE: SQL with TypedImperativeAggregate functions is planned in sort based aggregation,
* instead of hash based aggregation, as TypedImperativeAggregate use BinaryType as aggregation
* buffer's storage format, which is not supported by hash based aggregation. Hash based
* aggregation only support aggregation buffer of mutable types (like LongType, IntType that have
* fixed length and can be mutated in place in UnsafeRow).
* NOTE: The newly added ObjectHashAggregateExec supports TypedImperativeAggregate functions in
* hash based aggregation under some constraints.
*/
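To make the three stages concrete, here is a minimal sketch of a typed aggregate that simply concatenates a string column per group. The class is hypothetical and written against the 2.3-style interface, where update and merge return the buffer object (earlier versions returned Unit):

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical toy aggregate: concatenates a string column within each group.
case class ConcatAgg(
    child: Expression,
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[StringBuilder] {

  // Stage 1: the framework puts this object into the global buffer row.
  override def createAggregationBuffer(): StringBuilder = new StringBuilder

  // Stage 2, Partial/Complete: fold one input row into the buffer object.
  override def update(buffer: StringBuilder, input: InternalRow): StringBuilder =
    buffer.append(child.eval(input))

  // Stage 2, PartialMerge/Final: combine two deserialized buffer objects.
  override def merge(buffer: StringBuilder, other: StringBuilder): StringBuilder =
    buffer.append(other)

  // Stage 3: compute the final result from the buffer object.
  override def eval(buffer: StringBuilder): Any =
    UTF8String.fromString(buffer.toString)

  // Buffers travel through the shuffle as binary, hence explicit (de)serialization.
  override def serialize(buffer: StringBuilder): Array[Byte] =
    buffer.toString.getBytes(StandardCharsets.UTF_8)

  override def deserialize(bytes: Array[Byte]): StringBuilder =
    new StringBuilder(new String(bytes, StandardCharsets.UTF_8))

  override def children: Seq[Expression] = Seq(child)
  override def nullable: Boolean = true
  override def dataType: DataType = StringType

  override def withNewMutableAggBufferOffset(offset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = offset)

  override def withNewInputAggBufferOffset(offset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = offset)
}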
The framework maintains a global buffer row, and this turned out to be a huge pitfall.
2.1 Original code
/**
 * Concat a list of elements in a group.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Concat a list of elements in a group.")
case class CollectGroupXMLPath(
    cols: Seq[Expression],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect {

  def this(cols: Seq[Expression]) = this(cols, 0, 0)

  override val child = null
  override def children: Seq[Expression] = cols
  override def nullable: Boolean = true
  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
  override def aggBufferAttributes: Seq[AttributeReference] = super.aggBufferAttributes

  override def checkInputDataTypes(): TypeCheckResult = {
    val allOK = cols.forall(child =>
      !child.dataType.existsRecursively(_.isInstanceOf[MapType]))
    if (allOK) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("group_xmlpath() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "group_xmlpath"

  // All accumulation state is local to the function instance
  // (this local state is the root cause of the 2.3 breakage explained below).
  override protected[this] val buffer: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
  protected[this] val strbuffer: StringBuilder = new StringBuilder
  private var columnNames: Seq[(String, String)] = Seq.fill(cols.length)(("", ""))
  private var rootSpec = ("<root>", "</root>")
  private var rowSpec = ("<row>", "</row>")

  override def initialize(b: InternalRow): Unit = {
    buffer.clear()
    strbuffer.clear()
    initializeColNames
    strbuffer.append(rootSpec._1)
  }

  // The last argument may be an array literal: (root tag, row tag, one tag per column).
  private def initializeColNames = {
    cols.last match {
      case Literal(v, d) if d.isInstanceOf[ArrayType] =>
        val av = v.asInstanceOf[GenericArrayData]
        val names = av.array.map(_.toString.trim)
        val namepair = names.map(e => if (e.length > 0) (s"<$e>", s"</$e>") else ("", "")).toSeq
        rootSpec = namepair(0)
        rowSpec = namepair(1)
        columnNames = namepair.slice(2, namepair.length)
      case _ =>
    }
  }

  override def update(b: InternalRow, input: InternalRow): Unit = {
    strbuffer.append(rowSpec._1)
    for (i <- 0 to (cols.length - 2)) {
      strbuffer.append(columnNames(i)._1)
        .append(cols(i).eval(input))
        .append(columnNames(i)._2)
    }
    strbuffer.append(rowSpec._2)
  }

  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
    sys.error("group_xmlpath cannot be used in partial aggregations.")
  }

  override def eval(input: InternalRow): Any = {
    strbuffer.append(rootSpec._2)
    UTF8String.fromString(strbuffer.toString())
  }
}
The old code broke because both the update and the merge stages initialize the buffer, that is, call the initialize method again. In the previous version the buffer was a local field whose entire setup lived inside initialize, so the subsequent merge and update phases would wipe the local buffer clean. I tried patching the source so that the execution flow would skip the merge stage, but that caused eval never to be called and the results came out wrong. After carefully studying the interface comments, I decided to work through the global buffer instead.
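Schematically, in a two-phase plan the framework drives the function like this (pseudocode, not actual planner code; agg stands for the function instance):

// Map side (Partial mode):
agg.initialize(partialBufferRow)                       // fresh buffer object into the row
mapRows.foreach(row => agg.update(partialBufferRow, row))
agg.serializeAggregateBufferInPlace(partialBufferRow)  // buffer object -> bytes for the shuffle

// Reduce side (Final mode). Note that initialize runs AGAIN here,
// which is exactly what wiped out state kept outside the buffer row:
agg.initialize(finalBufferRow)
shuffledRows.foreach(row => agg.merge(finalBufferRow, row))
val result = agg.eval(finalBufferRow)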
2.2 Modified code
The final code is shown below. The key change is in update: instead of accumulating into local state, each formatted row fragment is appended to the object stored in the global buffer row (fetched via getBufferObject and written back at mutableAggBufferOffset), so the framework's re-initialization no longer destroys accumulated data:
case class CollectGroupXMLPath(
    cols: Seq[Expression],
    mutableAggBufferOffset: Int = 0,
    inputAggBufferOffset: Int = 0) extends Collect[mutable.ArrayBuffer[Any]] {

  def this(cols: Seq[Expression]) = this(cols, 0, 0)

  override val child = cols.head
  override def children: Seq[Expression] = cols
  override def nullable: Boolean = true
  override def dataType: DataType = StringType
  // override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
  // override def aggBufferAttributes: Seq[AttributeReference] = super.aggBufferAttributes

  override def checkInputDataTypes(): TypeCheckResult = {
    val allOK = cols.forall(child =>
      !child.dataType.existsRecursively(_.isInstanceOf[MapType]))
    if (allOK) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure("group_xmlpath() cannot have map type data")
    }
  }

  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
    copy(mutableAggBufferOffset = newMutableAggBufferOffset)

  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
    copy(inputAggBufferOffset = newInputAggBufferOffset)

  override def prettyName: String = "group_xmlpath"

  // Scratch builder for formatting a single row; cross-row state now lives in
  // the framework's global buffer row instead.
  protected[this] val strbuffer: StringBuilder = new StringBuilder
  private var columnNames: Seq[(String, String)] = Seq.fill(cols.length)(("", ""))
  private var rootSpec = ("<root>", "</root>")
  private var rowSpec = ("<row>", "</row>")
  private[this] val anyObjectType = ObjectType(classOf[AnyRef])

  // The old initialize override is gone: resetting state here is what
  // clobbered the buffer, and the framework now supplies the buffer object
  // through createAggregationBuffer.
  /* override def initialize(b: InternalRow): Unit = {
    buffer.clear()
    strbuffer.clear()
    initializeColNames
    strbuffer.append(rootSpec._1)
    createAggregationBuffer()
  } */

  private def initializeColNames = {
    cols.last match {
      case Literal(v, d) if d.isInstanceOf[ArrayType] =>
        val av = v.asInstanceOf[GenericArrayData]
        val names = av.array.map(_.toString.trim)
        val namepair = names.map(e => if (e.length > 0) (s"<$e>", s"</$e>") else ("", "")).toSeq
        rootSpec = namepair(0)
        rowSpec = namepair(1)
        columnNames = namepair.slice(2, namepair.length)
      case _ =>
    }
  }

  override def update(b: InternalRow, input: InternalRow): Unit = {
    // Note: clear the local scratch buffer first to avoid redundant data.
    strbuffer.clear()
    strbuffer.append(rowSpec._1)
    for (i <- 0 to (cols.length - 2)) {
      strbuffer.append(columnNames(i)._1)
        .append(cols(i).eval(input))
        .append(columnNames(i)._2)
    }
    strbuffer.append(rowSpec._2)
    val out = InternalRow.fromSeq(Array(UTF8String.fromString(strbuffer.toString())))
    // Write through to the global buffer row: append this row's fragment to
    // the buffer object stored at mutableAggBufferOffset.
    b(mutableAggBufferOffset) = getBufferObject(b) += out
  }

  private def getBufferObject(bufferRow: InternalRow): ArrayBuffer[Any] = {
    bufferRow.get(mutableAggBufferOffset, anyObjectType).asInstanceOf[ArrayBuffer[Any]]
  }

  override def merge(buffer: InternalRow, input: InternalRow): Unit = {
    super.merge(buffer, input)
  }

  override def eval(input: InternalRow): Any = {
    val head = input.toSeq(Seq(StringType)).head
    var buff = ArrayBuffer[UTF8String]()
    if (head.isInstanceOf[ArrayBuffer[UTF8String]]) {
      buff = head.asInstanceOf[ArrayBuffer[UTF8String]]
    }
    // Reformat the output: wrap all buffered row fragments in the root tag.
    val out = new mutable.StringBuilder()
    out.append(rootSpec._1)
    for (i <- 0 until buff.length) {
      out.append(buff(i))
    }
    out.append(rootSpec._2)
    UTF8String.fromString(out.toString())
  }

  private lazy val projection = UnsafeProjection.create(
    Array[DataType](ArrayType(elementType = child.dataType, containsNull = false)))

  override def serialize(obj: mutable.ArrayBuffer[Any]): Array[Byte] = {
    val array = new GenericArrayData(Array(UTF8String.fromString(strbuffer.toString())))
    projection.apply(InternalRow.apply(array)).getBytes()
  }

  override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty
}
3. Function registration
Finally, the function needs to be registered by adding one line to the FunctionRegistry object:
expression[CollectGroupXMLPath]("group_xmlpath"),
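Once registered and the build redeployed, the function can be exercised from SQL like any built-in (a hypothetical check; the table is made up):

// Quick check that the registration took effect:
spark.sql("DESCRIBE FUNCTION group_xmlpath").show(false)

// End-to-end test against a toy table:
spark.sql(
  """SELECT dept,
    |       group_xmlpath(name, salary, array('root', 'row', 'name', 'salary'))
    |FROM employees
    |GROUP BY dept""".stripMargin).show(false)
// Expected shape of each output cell:
// <root><row><name>...</name><salary>...</salary></row>...</root>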
To wrap up, a few asides with small tips for debugging Spark SQL:
1. Inspect the physical plan. Prepend explain to the SQL to print the physical plan for analysis (see the example below).
2. println liberally: print intermediate results to trace what gets called.
3. Exception debugging: where the call relationships are unclear, throw an exception and read the call stack.
4. Breakpoints: in a parallel framework like Spark, pausing too long at a breakpoint can cause lost heartbeats and the like, so breakpoints are not always practical.
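For tip 1, for instance (hypothetical table, function registered as above):

spark.sql(
  """EXPLAIN EXTENDED
    |SELECT dept, group_xmlpath(name, array('root', 'row', 'name'))
    |FROM employees GROUP BY dept""".stripMargin).show(false)
// The printed plan shows whether SortAggregate or ObjectHashAggregate was
// chosen for the TypedImperativeAggregate (cf. the NOTEs quoted earlier).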