mapPartitions explained
Underlying source code
/**
 * Return a new RDD by applying a function to each partition of this RDD.
 *
 * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
 * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
 */
def mapPartitions[U: ClassTag](
    f: Iterator[T] => Iterator[U],
    preservesPartitioning: Boolean = false): RDD[U] = withScope {
  val cleanedF = sc.clean(f)
  new MapPartitionsRDD(
    this,
    (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(iter),
    preservesPartitioning)
}
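The source only wraps `f` in a `MapPartitionsRDD`, which calls it once per partition with an iterator over that partition's elements. The plain-Scala sketch below models that behavior without Spark, as an illustration only: `sliceRanges` and `modelMapPartitions` are hypothetical helpers, and the range arithmetic is assumed to mirror how `parallelize` splits a local `Seq` into `numSlices` partitions.

```scala
// Plain-Scala model of mapPartitions semantics (no Spark involved).
// sliceRanges and modelMapPartitions are hypothetical helpers, not Spark APIs.

// Split [0, length) into numSlices contiguous (start, end) ranges,
// assumed to match how parallelize slices a local Seq
def sliceRanges(length: Int, numSlices: Int): Seq[(Int, Int)] =
  (0 until numSlices).map { i =>
    val start = ((i.toLong * length) / numSlices).toInt
    val end = (((i + 1).toLong * length) / numSlices).toInt
    (start, end)
  }

def modelMapPartitions[T, U](data: Seq[T], numSlices: Int)(
    f: Iterator[T] => Iterator[U]): Seq[U] =
  sliceRanges(data.length, numSlices).flatMap { case (start, end) =>
    // f is called once per partition and sees only that partition's elements
    f(data.slice(start, end).iterator).toList
  }

// Example: emit one value per partition, namely the partition's sum
val sums = modelMapPartitions(Seq(3, 8, 9, 4, 2), 2)(it => Iterator(it.sum))
println(sums)
```

Because `f` returns an iterator rather than a single value, a partition can produce any number of outputs (here exactly one, 3 + 8 and 9 + 4 + 2).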
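From the source we can see that mapPartitions takes a function f that receives an iterator over the elements of one partition and must itself return an iterator, plus a Boolean preservesPartitioning (false by default). Spark calls f once per partition; the partition index is passed to the wrapping closure but ignored, so f never sees it. To find out which elements belong to the same partition, we can tag each partition with a random UUID as follows: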
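import java.util.UUID
import org.junit.Test

// Assumes `sc` is a SparkContext field initialized elsewhere in the test class
@Test
def showData: Unit = {
  sc.parallelize(Seq(3, 8, 9, 4, 2), 2).mapPartitions(
    item => {
      // One random UUID per call of the function, i.e. per partition
      val string: String = UUID.randomUUID().toString
      // Buffer the partition first: foreach would otherwise exhaust the iterator
      val elems = item.toList
      elems.foreach(it => println(s"${string}----${it}"))
      elems.iterator
    }
  ).collect()
}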
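One caveat about logging inside mapPartitions: a Scala `Iterator` can be traversed only once. If the function calls `foreach` on the partition iterator and then returns that same iterator, the returned iterator is already exhausted, so `collect()` receives no elements even though every line was printed. The plain-Scala sketch below (no Spark needed) shows the effect and two workarounds; buffering with `toList` assumes the partition fits in memory, while wrapping with `map` stays lazy.

```scala
// A Scala Iterator is single-pass: traversing it consumes it
val partitionIter = Iterator(3, 8)
partitionIter.foreach(x => println(s"seen $x"))
// Nothing is left for whoever consumes the returned iterator
val leftover = partitionIter.toList

// Workaround 1: buffer the partition, then return a fresh iterator over the buffer
val buffered = Iterator(3, 8).toList
buffered.foreach(x => println(s"seen $x"))
val fromBuffer = buffered.iterator.toList

// Workaround 2: attach the side effect lazily; it runs as elements are consumed
val fromLazy = Iterator(3, 8).map { x => println(s"seen $x"); x }.toList
```

With the lazy variant the printing happens only when Spark actually pulls elements from the returned iterator, which avoids holding a whole partition in memory.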
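Output (lines tagged with the same UUID come from the same partition: 3 and 8 from one, 9, 4 and 2 from the other; the lines interleave because the two partitions' tasks ran in parallel):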
01cb417e-08f5-4c21-b5c1-c83633fcf4a9----3
2969aa7e-fd92-4581-8375-6431b079e416----9
01cb417e-08f5-4c21-b5c1-c83633fcf4a9----8
2969aa7e-fd92-4581-8375-6431b079e416----4
2969aa7e-fd92-4581-8375-6431b079e416----2
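The UUID only tells the partitions apart; it does not reveal the partition index. Spark's `mapPartitionsWithIndex` takes a function of type `(Int, Iterator[T]) => Iterator[U]` and passes that index in, so in Spark the equivalent call would be `sc.parallelize(Seq(3, 8, 9, 4, 2), 2).mapPartitionsWithIndex((index, iter) => iter.map(x => (index, x))).collect()`. The sketch below models that behavior in plain Scala for illustration; `modelMapPartitionsWithIndex` is a hypothetical helper, and its slicing is assumed to match `parallelize`.

```scala
// Plain-Scala model of mapPartitionsWithIndex (no Spark involved).
// modelMapPartitionsWithIndex is a hypothetical helper, not a Spark API.
def modelMapPartitionsWithIndex[T, U](data: Seq[T], numSlices: Int)(
    f: (Int, Iterator[T]) => Iterator[U]): Seq[U] =
  (0 until numSlices).flatMap { index =>
    // Contiguous range arithmetic, assumed to match how parallelize splits a Seq
    val start = ((index.toLong * data.length) / numSlices).toInt
    val end = (((index + 1).toLong * data.length) / numSlices).toInt
    f(index, data.slice(start, end).iterator).toList
  }

// Tag every element with its partition index instead of a random UUID
val tagged = modelMapPartitionsWithIndex(Seq(3, 8, 9, 4, 2), 2) { (index, iter) =>
  iter.map(x => (index, x))
}
tagged.foreach { case (index, x) => println(s"partition $index----$x") }
```

The grouping matches the UUID output above: 3 and 8 land in partition 0, while 9, 4 and 2 land in partition 1.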