map将函数作用到数据集的每一个元素上,生成一个新的分布式的数据集(RDD)返回
map函数的源码:
def map(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each element of this RDD. >>> rdd = sc.parallelize(["b", "a", "c"]) >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning)
map将每一条输入执行func操作并对应返回一个对象,形成一个新的rdd,如源码中的rdd.map(lambda x: (x, 1) --> [('a', 1), ('b', 1), ('c', 1)]
flatMap会先执行map的操作,再将所有对象合并为一个对象,返回值是一个Sequence
flatMap源码:
def flatMap(self, f, preservesPartitioning=False): """ >>> rdd = sc.parallelize([2, 3, 4]) >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect()) [1, 1, 1, 2, 2, 3] >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect()) [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning)
注意:flatMap将输入执行func操作时,对象必须是可迭代的
map与flatMap的区别:
1 from pyspark import SparkConf, SparkContext 2 3 conf = SparkConf() 4 sc = SparkContext(conf=conf) 5 6 7 def func_map(): 8 data = ["hello world", "hello fly"] 9 data_rdd = sc.parallelize(data) 10 map_rdd = data_rdd.map(lambda s: s.split(" ")) 11 print("map print:{}".format(map_rdd.collect())) 12 13 14 def func_flat_map(): 15 data = ["hello world", "hello fly"] 16 data_rdd = sc.parallelize(data) 17 flat_rdd = data_rdd.flatMap(lambda s: s.split(" ")) 18 print("flatMap print:{}".format(flat_rdd.collect())) 19 20 21 func_map() 22 func_flat_map() 23 sc.stop()
执行结果:
map print:[['hello', 'world'], ['hello', 'fly']] flatMap print:['hello', 'world', 'hello', 'fly']
可以看出,map对 "hello world", "hello fly"这两个对象分别映射为['hello', 'world'], ['hello', 'fly'],而flatMap在map的基础上做了一个合并操作,将这两个对象合并为一个['hello', 'world', 'hello', 'fly'],这就造就了flatMap在词频统计方面的优势。