scala - How to define and use a User-Defined Aggregate Function in Spark SQL?
I know how to write a UDF in Spark SQL:
def belowThreshold(power: Int): Boolean = {
  return power < -40
}

sqlContext.udf.register("belowThreshold", belowThreshold _)
Can I do something similar to define an aggregate function? How is this done?
For context, I want to run the following SQL query:
val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")
It should return something like
Row(span1, false, T0)
I want the aggregate function to tell me if there are any values of opticalReceivePower in the groups defined by span and timestamp that are below the threshold. Do I need to write my UDAF differently from the UDF I pasted above?
Supported methods
Spark 2.0+ (optionally 1.6+, but with a slightly different API):
It is possible to use Aggregators on typed Datasets:
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  def zero = false                                       // initial buffer value
  def reduce(acc: Boolean, x: I) = acc | f(x)            // fold one record into the buffer
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2  // combine partial buffers
  def finish(acc: Boolean) = acc                         // final result

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
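For a concrete run, here is a minimal sketch with made-up sample data (the group/power column names are my own, and it assumes a SparkSession named spark is in scope):

import spark.implicits._

val df = Seq(("a", 0), ("a", 1), ("b", 30), ("b", -50)).toDF("group", "power")

df.as[(String, Int)]
  .groupByKey(_._1)                             // group by the first tuple element
  .agg(belowThreshold.name("belowThreshold"))   // apply the typed aggregator
  .show

// roughly:
// +-----+--------------+
// |value|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+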
Spark >= 1.5:
In Spark 1.5 you can create a UDAF like this, although it is most likely an overkill:
import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
  // Schema you get as an input
  def inputSchema = new StructType().add("power", IntegerType)
  // Schema of the row which is used for aggregation
  def bufferSchema = new StructType().add("ind", BooleanType)
  // Returned type
  def dataType = BooleanType
  // Self-explaining
  def deterministic = true
  // zero value
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
  // Similar to seqOp in aggregate
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0))
      buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
  }
  // Similar to combOp in aggregate
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))
  }
  // Called on exit to get the return value
  def evaluate(buffer: Row) = buffer.getBoolean(0)
}
Example usage:
df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
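To call it from SQL, as in the query in the question, a UDAF can also be registered by name. A sketch, assuming ifDF has been registered as a temporary table:

sqlContext.udf.register("belowThreshold", belowThreshold)

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")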
Spark 1.4 workaround:
I am not sure if I correctly understand your requirements, but as far as I can tell plain old aggregation should be enough here:
val df = sc.parallelize(Seq(
  ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
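The same trick translates directly to SQL for the query in the question (a sketch; it assumes ifDF is registered as a temporary table):

val aggDF = sqlContext.sql("""SELECT span, timestamp,
                                     SUM(CAST(opticalReceivePower < -40 AS INT)) > 0 AS belowThreshold
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")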
Spark <= 1.4:
As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).
Unsupported / internal methods
Internally Spark uses a number of classes, including ImperativeAggregates and DeclarativeAggregates.
These are intended for internal usage and may change without further notice, so they are not something you want to use in your production code; but just for completeness, here is how BelowThreshold could be implemented as a DeclarativeAggregate (tested with Spark 2.2-SNAPSHOT):
import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression)
    extends DeclarativeAggregate {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
}
It should be further wrapped with an equivalent of withAggregateFunction.
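For illustration only, such a wrapper could look roughly like this, mirroring what Spark's own functions.scala does internally (toAggregateExpression and the public Column constructor are Catalyst-facing internals and may change between versions):

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// Hypothetical helper, analogous to Spark's internal withAggregateFunction
def belowThreshold(col: Column, threshold: Column): Column =
  new Column(BelowThreshold(col.expr, threshold.expr).toAggregateExpression())

// Usage sketch:
// df.groupBy($"group").agg(belowThreshold($"power", lit(-40))).show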