大数据学习之 sparkSql UDF(自定义函数)

本文详细介绍了Apache Spark中自定义函数的使用,包括UDF、UDAF的实现及窗体函数的应用,通过实例展示了如何在Spark SQL和DataFrame API中注册和使用自定义函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

注意:需要引入spark-hive的依赖

 

目录

第一部分:自定义函数(常用的一些窗体函数)

第二部分:自定义聚合函数(弱类型)

第三部分:自定义聚合函数(强类型)


package com.spark.self

import org.apache.spark.sql.SparkSession

object UdfDemo {

  def main(args: Array[String]): Unit = {
    newUdf
    newDfUdf
    chuangTi
  }

  /**
    * 根据年龄大小返回是否成年 成年:true,未成年:false
    */
  def isAdult(age: Int) = {
    if (age < 18) {
      false
    } else {
      true
    }
  }

  /**
    * 根据名字计算姓名的的长度
    */
  def isLength(name: String) = {
    name.length
  }

  def con(name: String, age: Int) = {
    age + ":" + name
  }

  def con1(name: String, age: Int) = {
    name + ":" + age
  }

  /**
    * 新版本(Spark2.x)Spark Sql udf示例
    */
  def newUdf() {
    //spark初始化
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))

    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")

    // 注册一张user表
    userDF.createOrReplaceTempView("user")

    //注册自定义函数(通过匿名函数)
    spark.udf.register("strLen", (str: String) => str.length())
    spark.udf.register("NameAddAge", (x: String, y: Int) => x + ":" + y)

    //注册自定义函数(通过实名函数)
    spark.udf.register("isAdult", isAdult _)
    spark.udf.register("isLength", isLength _)
    spark.udf.register("con", con(_: String, _: Int))
    spark.sql("select *," +
      "isLength(name) as name_len_shi," +
      "strLen(name) as name_len_ni," +
      "NameAddAge(name,age) as NameAge," +
      "con(name,age) as AgeName," +
      "isAdult(age) as isAdult from user").show

    //关闭
    spark.stop()

  }

  /**
    * 新版本(Spark2.x)DataFrame udf示例
    */
  def newDfUdf() {
    val spark = SparkSession.builder().appName("newDfUdf").master("local[*]").getOrCreate()

    // 构造测试数据,有两个字段、名字和年龄
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))

    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    import org.apache.spark.sql.functions._
    //注册自定义函数(通过匿名函数)
    val strLen = udf((str: String) => str.length())
    val ageAddName = udf((str: String, str1: Int) => str1 + ":" + str)
    //注册自定义函数(通过实名函数)
    val udf_isAdult = udf(isAdult _)
    val udf_nameAddAge = udf(con1(_: String, _: Int))

    //通过withColumn添加列
    userDF.withColumn("name_len", strLen(col("name")))
      .withColumn("isAdult", udf_isAdult(col("age")))
      .withColumn("ageAddName", ageAddName(col("name"), col("age")))
      .withColumn("nameAddAge", udf_nameAddAge(col("name"), col("age")))
      .show
    //通过select添加列
    userDF.select(col("*"),
      strLen(col("name")) as "name_len",
      udf_isAdult(col("age")) as "isAdult",
      ageAddName(col("name"), col("age")) as "ageAddName",
      udf_nameAddAge(col("name"), col("age")) as "nameAddAge"
    ).show

    //关闭
    spark.stop()
  }

  /**
    * 窗体函数
    */
  def chuangTi(): Unit = {
    //窗体函数必须加入hive 支持.enableHiveSupport()
    val spark = SparkSession.builder().appName("newDfUdf").master("local[*]").enableHiveSupport().getOrCreate()


    val scoreData = Array(
      ("a", 1, 80),
      ("b", 1, 78),
      ("c", 1, 95),
      ("d", 2, 74),
      ("e", 2, 92),
      ("f", 3, 99),
      ("g", 3, 99),
      ("h", 3, 45),
      ("i", 3, 55),
      ("j", 3, 78))
    val scoreDF = spark.createDataFrame(scoreData).toDF("name", "class", "score")
    scoreDF.createOrReplaceTempView("score")
    //求每个班最高成绩学生的信息
    /**
      * 1,2,2,4,5,6.。。。。这是rank()的形式
      *
      * 1,2,2,3,4,5,。。。。这是dense_rank()的形式
      *
      * 1,2,3,4,5,6.。。。。。这是row_number()涵数形式
      */
    //    //方式 一
    spark.sql("select name,class,score,rank() over(partition by class order by score desc) rank from score").show()
    spark.sql("select name,class,score,dense_rank() over(partition by class order by score desc) rank from score").show()
    spark.sql("select name,class,score,row_number() over(partition by class order by score desc) rank from score").show()
    //方式二 因为只根据class 进行group by  所以只能展示一个列
    spark.sql("select class, max(score) max from score group by class").show()
    //想要展示多个列使用以下方式
    spark.sql("select a.name, b.class, b.max from score a, " +
      "(select class, max(score) max from score group by class) as b " +
      "where a.score = b.max").show()

    //求班级的总分
    spark.sql("select class,sum(score) over(partition by class order by score desc) sumNumber from score").show()
    spark.sql("select class,sum(score) totalScore from score group by class").show()

    //求每个班级的平均分
    spark.sql("select class,avg(score) avgNumber from score group by class").show()

  }
}

 

 

第二部分:自定义聚合函数(弱类型)

package com.spark.self

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    //spark初始化
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    //创建测试df
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    // 注册一张user表
    userDF.createOrReplaceTempView("user")
    //创建聚合函数对象
    val udaf = new MyAgeAvgFunction
    //将聚合函数注册到spark中
    spark.udf.register("ageAvg", udaf)
    spark.sql("select ageAvg(age) as avg from user").show()
  }
}

//作用自定义聚合函数
//集成UserDefinedAggregateFunction类
class MyAgeAvgFunction extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  //计算时的数据结构 比如平均值需要用到总数sum 和 个数 count
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  //函数返回的类型
  override def dataType: DataType = DoubleType

  //函数是否稳定
  override def deterministic: Boolean = true

  //计算之前缓冲区的初始化(sum 和count的初始化)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L //0 代表sum 1代表count
    buffer(1) = 0L

  }

  //更新查询结果更新缓冲区的数据
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  //将多个节点的缓冲区合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    //sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    //count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  //计算
  override def evaluate(buffer: Row): Any = {
    (buffer.getLong(0) / buffer.getLong(1)).toDouble
  }
}

 

 

第三部分:自定义聚合函数(强类型)

package com.spark.self

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
import org.apache.spark.sql.SparkSession

object UDAFClassDemo {
  def main(args: Array[String]): Unit = {
    //spark初始化
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()
    import spark.implicits._
    //创建聚合函数对象
    val udaf = new MyAgeAvgClassFunction
    //将聚合函数转换成查询列
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")
    //    val userDF: DataFrame = spark.read.json("in/user.json")
    //    val userDF: DataFrame = spark.read.load("in/user.json")//load 读取Parquet文件
    val userDF: DataFrame = spark.read.format("json").load("in/user.json")
    //如果想使用load 读取json文件就需要使用format定义文件格式为json文件就可以了
    //load 读取Parquet文件
    val userDs: Dataset[UserBean] = userDF.as[UserBean]
    //应用函数
    userDs.select(avgCol).show()
    //把数据保存到相应的路劲下默认保存为Parquet文件
    //    userDs.write.save("out/")

    
    //如果想保存为json文件就需要定义文件类型
//    userDs.write.format("json").save("out/")

    //使用mode("append")定义保存数据的模式,就不需要删除文件直接执行,追加
    //append 在老的文件基础上追加新的文件 ,overwite 用新的文件覆盖老的
//    userDs.write.format("json").mode("append").save("out/")
    userDs.write.format("json").mode("overwrite").save("out/")

    //释放资源
    spark.stop()
  }
}

//作用自定义聚合函数(强类型)
//集成Aggregator类,设定泛型
//定义输入的参数
case class UserBean(name: String, age: Long)

//定义缓冲区的参数
case class avgBuffer(var sum: Long, var count: Long)

class MyAgeAvgClassFunction extends Aggregator[UserBean, avgBuffer, Double] {
  //初始化
  override def zero: avgBuffer = {
    avgBuffer(0, 0)
  }

  //聚合数据
  override def reduce(b: avgBuffer, a: UserBean): avgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  //缓冲区的合并数据
  override def merge(b1: avgBuffer, b2: avgBuffer): avgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  //完成计算
  override def finish(reduction: avgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  //自定义的类型用product
  override def bufferEncoder: Encoder[avgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

总结:

1. DF(DataFrame) 转DS(DataSet) 过程中一定要加上  import spark.implicits._ ,import spark.implicits._提供了更加强大的功能

如果不添加错误如下:

 

 

2. 读取json文件的时候,自动判定age是Long类型 ,所有本篇文章中的age为Long类型 否则会报错

 

 

user.json 文件中的数据如下:

{"name":"Michael","age":45}
{"name":"Andy", "age":30}
{"name":"Justin", "age":19}

3. json 文件不是传统的正确的json文件,因为spark 是按照行来读取,所以只需要保证每一行是标准的json文件即可,所以使用的json文件显示有错,但是并不影响

4. 读取文件的类型

spark.read.json("path")只能 读取json文件的数据

spark.read.load("path") load 默认读取parquet文件的数据

spark.read.format("json").load("path") 这是通用的方法,通过format来定义文件类型,使用load 来加载数据

5. 保存文件的方式

dataFrame(DF) 或者DataSet(DS)

(1)(DF|DS).write.save("path") 将数据以parquet文件的形式保存到相应路径下

(2)(DF|DS).write.format("json").save("path") 将数据以json文件的形式保存到相应的路径

(1)(2)执行之后存在一个问题:第二次执行的时候需要把path目录下的文件全部删除 否则会报错 为了避免以上情况需要定义保存的模式

(3)DF|DS).write.format("json").mode("append").save("path")  这中方式为通用方式 append 追加 ,在原来原件的基础上追加新的文件,overwrite:用新文件覆盖老文件,默认为default 需要删除原有文件,否则会又报错

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值