Note: the spark-hive dependency must be added to the project (it backs the .enableHiveSupport() call used below).
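With sbt the dependency can be declared roughly like this (the version 2.4.8 is only a placeholder; match it to the Spark version you actually use, or add the corresponding spark-hive_2.xx artifact in Maven):

// sbt: spark-hive provides the Hive integration behind .enableHiveSupport(); adjust the version
libraryDependencies += "org.apache.spark" %% "spark-hive" % "2.4.8"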
package com.spark.self

import org.apache.spark.sql.SparkSession

object UdfDemo {

  def main(args: Array[String]): Unit = {
    newUdf()
    newDfUdf()
    chuangTi()
  }

  /**
   * Returns whether the person is an adult based on age: adult = true, minor = false
   */
  def isAdult(age: Int) = {
    if (age < 18) {
      false
    } else {
      true
    }
  }

  /**
   * Computes the length of a name
   */
  def isLength(name: String) = {
    name.length
  }

  def con(name: String, age: Int) = {
    age + ":" + name
  }

  def con1(name: String, age: Int) = {
    name + ":" + age
  }
  /**
   * Spark SQL UDF example (Spark 2.x)
   */
  def newUdf(): Unit = {
    // Initialize SparkSession
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()
    // Build test data with two fields: name and age
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    // Create the test DataFrame
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    // Register a temporary view named "user"
    userDF.createOrReplaceTempView("user")
    // Register UDFs via anonymous functions
    spark.udf.register("strLen", (str: String) => str.length())
    spark.udf.register("NameAddAge", (x: String, y: Int) => x + ":" + y)
    // Register UDFs via named functions
    spark.udf.register("isAdult", isAdult _)
    spark.udf.register("isLength", isLength _)
    spark.udf.register("con", con(_: String, _: Int))
    spark.sql("select *," +
      "isLength(name) as name_len_shi," +
      "strLen(name) as name_len_ni," +
      "NameAddAge(name,age) as NameAge," +
      "con(name,age) as AgeName," +
      "isAdult(age) as isAdult from user").show
    // Stop the SparkSession
    spark.stop()
  }
  /**
   * DataFrame UDF example (Spark 2.x)
   */
  def newDfUdf(): Unit = {
    val spark = SparkSession.builder().appName("newDfUdf").master("local[*]").getOrCreate()
    // Build test data with two fields: name and age
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    // Create the test DataFrame
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    import org.apache.spark.sql.functions._
    // Define UDFs via anonymous functions
    val strLen = udf((str: String) => str.length())
    val ageAddName = udf((str: String, str1: Int) => str1 + ":" + str)
    // Define UDFs via named functions
    val udf_isAdult = udf(isAdult _)
    val udf_nameAddAge = udf(con1(_: String, _: Int))
    // Add columns via withColumn
    userDF.withColumn("name_len", strLen(col("name")))
      .withColumn("isAdult", udf_isAdult(col("age")))
      .withColumn("ageAddName", ageAddName(col("name"), col("age")))
      .withColumn("nameAddAge", udf_nameAddAge(col("name"), col("age")))
      .show
    // Add columns via select
    userDF.select(col("*"),
      strLen(col("name")) as "name_len",
      udf_isAdult(col("age")) as "isAdult",
      ageAddName(col("name"), col("age")) as "ageAddName",
      udf_nameAddAge(col("name"), col("age")) as "nameAddAge"
    ).show
    // Stop the SparkSession
    spark.stop()
  }
  /**
   * Window function examples
   */
  def chuangTi(): Unit = {
    // Window functions require Hive support: .enableHiveSupport()
    val spark = SparkSession.builder().appName("newDfUdf").master("local[*]").enableHiveSupport().getOrCreate()
    val scoreData = Array(
      ("a", 1, 80),
      ("b", 1, 78),
      ("c", 1, 95),
      ("d", 2, 74),
      ("e", 2, 92),
      ("f", 3, 99),
      ("g", 3, 99),
      ("h", 3, 45),
      ("i", 3, 55),
      ("j", 3, 78))
    val scoreDF = spark.createDataFrame(scoreData).toDF("name", "class", "score")
    scoreDF.createOrReplaceTempView("score")
    // Find the top-scoring student in each class
    /**
     * rank()       numbers ties alike and leaves gaps:  1, 2, 2, 4, 5, 6, ...
     * dense_rank() numbers ties alike with no gaps:     1, 2, 2, 3, 4, 5, ...
     * row_number() numbers every row uniquely:          1, 2, 3, 4, 5, 6, ...
     */
    // Approach 1: window functions
    spark.sql("select name,class,score,rank() over(partition by class order by score desc) rank from score").show()
    spark.sql("select name,class,score,dense_rank() over(partition by class order by score desc) rank from score").show()
    spark.sql("select name,class,score,row_number() over(partition by class order by score desc) rank from score").show()
    // Approach 2: grouping by class alone can only return the grouped column and the aggregate
    spark.sql("select class, max(score) max from score group by class").show()
    // To return more columns, join the aggregate back to the original table
    spark.sql("select a.name, b.class, b.max from score a, " +
      "(select class, max(score) max from score group by class) as b " +
      "where a.score = b.max").show()
    // Total score per class (no order by in the window, otherwise it becomes a running sum)
    spark.sql("select class,sum(score) over(partition by class) sumNumber from score").show()
    spark.sql("select class,sum(score) totalScore from score group by class").show()
    // Average score per class
    spark.sql("select class,avg(score) avgNumber from score group by class").show()
  }
}
Part 2: Custom aggregate function (weakly typed)
package com.spark.self

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    // Initialize SparkSession
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()
    val userData = Array(("Leo", 16), ("Marry", 21), ("Jack", 14), ("Tom", 18))
    // Create the test DataFrame
    val userDF = spark.createDataFrame(userData).toDF("name", "age")
    // Register a temporary view named "user"
    userDF.createOrReplaceTempView("user")
    // Create the aggregate function instance
    val udaf = new MyAgeAvgFunction
    // Register the aggregate function with Spark
    spark.udf.register("ageAvg", udaf)
    spark.sql("select ageAvg(age) as avg from user").show()
  }
}
// Custom aggregate function (weakly typed)
// Extends UserDefinedAggregateFunction
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  // Schema of the input arguments
  override def inputSchema: StructType = {
    new StructType().add("age", LongType)
  }

  // Schema of the aggregation buffer; an average needs a running sum and a count
  override def bufferSchema: StructType = {
    new StructType().add("sum", LongType).add("count", LongType)
  }

  // Return type of the function
  override def dataType: DataType = DoubleType

  // Whether the function is deterministic
  override def deterministic: Boolean = true

  // Initialize the buffer before computation (sum and count)
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // index 0 is sum, index 1 is count
    buffer(1) = 0L
  }

  // Update the buffer with each input row
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Merge buffers from multiple nodes
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // sum
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // count
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Compute the final result (convert to Double before dividing to avoid integer division)
  override def evaluate(buffer: Row): Any = {
    buffer.getLong(0).toDouble / buffer.getLong(1)
  }
}
Part 3: Custom aggregate function (strongly typed)
package com.spark.self

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql._
import org.apache.spark.sql.SparkSession

object UDAFClassDemo {
  def main(args: Array[String]): Unit = {
    // Initialize SparkSession
    val spark = SparkSession.builder().appName("newUdf").master("local[*]").getOrCreate()
    import spark.implicits._
    // Create the aggregate function instance
    val udaf = new MyAgeAvgClassFunction
    // Turn the aggregate function into a typed query column
    val avgCol: TypedColumn[UserBean, Double] = udaf.toColumn.name("avgAge")
    // val userDF: DataFrame = spark.read.json("in/user.json")
    // val userDF: DataFrame = spark.read.load("in/user.json") // load reads Parquet files by default
    val userDF: DataFrame = spark.read.format("json").load("in/user.json")
    // To read a JSON file with load, declare the format as "json" via format(...)
    val userDs: Dataset[UserBean] = userDF.as[UserBean]
    // Apply the function
    userDs.select(avgCol).show()
    // Save the data under the given path; the default format is Parquet
    // userDs.write.save("out/")
    // To save as JSON, declare the file type explicitly
    // userDs.write.format("json").save("out/")
    // mode("append") appends new data, so the output directory need not be deleted before rerunning;
    // append adds new files alongside the old ones, overwrite replaces the old files with new ones
    // userDs.write.format("json").mode("append").save("out/")
    userDs.write.format("json").mode("overwrite").save("out/")
    // Release resources
    spark.stop()
  }
}
// Custom aggregate function (strongly typed)
// Extends Aggregator, with the input, buffer, and output types as type parameters

// Input type
case class UserBean(name: String, age: Long)

// Buffer type
case class avgBuffer(var sum: Long, var count: Long)

class MyAgeAvgClassFunction extends Aggregator[UserBean, avgBuffer, Double] {

  // Initial (zero) value of the buffer
  override def zero: avgBuffer = {
    avgBuffer(0, 0)
  }

  // Aggregate one input record into the buffer
  override def reduce(b: avgBuffer, a: UserBean): avgBuffer = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // Merge two buffers
  override def merge(b1: avgBuffer, b2: avgBuffer): avgBuffer = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }

  // Compute the final result
  override def finish(reduction: avgBuffer): Double = {
    reduction.sum.toDouble / reduction.count
  }

  // Use Encoders.product for custom case class types
  override def bufferEncoder: Encoder[avgBuffer] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
Summary:
1. When converting a DataFrame (DF) to a Dataset (DS), be sure to add import spark.implicits._; it brings the implicit Encoders into scope, and without it the .as[UserBean] conversion will not compile.
2. When reading the JSON file, Spark infers age as Long, which is why age is declared as Long throughout this article; otherwise an error is thrown.
The data in user.json is:
{"name":"Michael","age":45}
{"name":"Andy", "age":30}
{"name":"Justin", "age":19}
3. The JSON file is not a single well-formed JSON document. Spark reads it line by line, so each line only needs to be valid JSON on its own; the file may look malformed as a whole, but that does not matter.
4. Ways to read files (see the sketch after this list):
spark.read.json("path") can only read JSON files
spark.read.load("path") reads Parquet files by default
spark.read.format("json").load("path") is the generic form: format declares the file type and load loads the data
5. Ways to save a DataFrame (DF) or Dataset (DS) (see the sketch after this list):
(1) (DF|DS).write.save("path") saves the data as Parquet files under the given path
(2) (DF|DS).write.format("json").save("path") saves the data as JSON files under the given path
With (1) and (2) there is a problem: before running a second time, all files under path must be deleted, otherwise an error is thrown. To avoid this, specify a save mode:
(3) (DF|DS).write.format("json").mode("append").save("path") is the generic form. append adds new files alongside the old ones; overwrite replaces the old files with new ones. The default mode throws an error if the path already exists, so the old files would have to be deleted first.
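A minimal sketch tying points 4 and 5 together, assuming the same in/user.json file used above; the object name ReadWriteDemo and the out/parquet and out/json output paths are only illustrative:

package com.spark.self

import org.apache.spark.sql.{DataFrame, SparkSession}

object ReadWriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("ReadWriteDemo").master("local[*]").getOrCreate()

    // Reading: json() only handles JSON; load() defaults to Parquet;
    // format(...).load(...) is the generic form
    val fromJson: DataFrame = spark.read.json("in/user.json")
    val generic: DataFrame = spark.read.format("json").load("in/user.json")

    // Saving: save() defaults to Parquet; format(...) picks the file type;
    // mode(...) decides what happens when the path already exists (the default errors out)
    fromJson.write.mode("overwrite").save("out/parquet")
    generic.write.format("json").mode("append").save("out/json")

    spark.stop()
  }
}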