In PySpark, map() and mapPartitions() both apply a transformation to the elements of an RDD (Resilient Distributed Dataset). A DataFrame does not expose these methods directly, so you reach them through its underlying RDD with df.rdd. The two functions differ in behavior and usage.

map() function:

  • The map() function applies the provided function to each element of the RDD individually. For df.rdd, each element is a Row.
  • Partitions are still processed in parallel, but within each partition the function is invoked once per element.
  • The input and output of the map() function are one-to-one. That is, each input element produces exactly one output element.
  • The map() function is suitable for transformations that depend only on the current element and need no per-partition setup or state shared between elements.

Example:

# Using map() to add 1 to each element
df = spark.createDataFrame([(1,), (2,), (3,)], ["value"])
df_mapped = df.rdd.map(lambda row: row[0] + 1)
df_mapped.collect()

Output: [2, 3, 4]

mapPartitions() function:

  • The mapPartitions() function applies the provided function once per partition of the RDD, rather than once per element.
  • The provided function receives an iterator over the elements of a partition and returns an iterator (or any iterable, such as a generator) of output elements. This allows operations that look at the whole partition or maintain state across elements within a partition.
  • The input and output of the mapPartitions() function can have different cardinalities. That is, the number of output elements can differ from the number of input elements.
  • The mapPartitions() function can be used for optimizations when the transformation needs expensive setup, such as opening a database connection, because the setup runs once per partition instead of once per element.
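The points about state and cardinality can be sketched with two partition functions. This is a minimal illustration, not a definitive implementation: the names sum_partition and running_total are hypothetical, and the partitions are simulated with plain Python lists so the behavior can be checked without a cluster.

```python
from typing import Iterable, Iterator


def sum_partition(rows: Iterable) -> Iterator[int]:
    """Collapse a whole partition into one value (many inputs, one output)."""
    total = 0
    for row in rows:
        total += row[0]  # row[0] is the "value" column used in this answer
    yield total


def running_total(rows: Iterable) -> Iterator[int]:
    """Keep state across elements: emit the cumulative sum within the partition."""
    total = 0
    for row in rows:
        total += row[0]
        yield total


# Plain-Python stand-in for two partitions (an assumption for illustration;
# Spark decides the real split of rows across partitions).
partitions = [[(1,), (2,)], [(3,), (4,)]]

partition_sums = [v for part in partitions for v in sum_partition(iter(part))]
running = [v for part in partitions for v in running_total(iter(part))]

print(partition_sums)  # [3, 7]
print(running)         # [1, 3, 3, 7]
```

With Spark, the same functions would be passed as df.rdd.mapPartitions(sum_partition). The result then depends on how the rows are distributed, and an empty partition would contribute a 0 with this sketch.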

Example:

# Using mapPartitions() to multiply all elements within each partition by 10
df = spark.createDataFrame([(1,), (2,), (3,)], ["value"])
df_mapped = df.rdd.mapPartitions(lambda iterator: (row[0] * 10 for row in iterator))
df_mapped.collect()

Output: [10, 20, 30]
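The example above would give the same result with map(), so it does not show where mapPartitions() pays off. The sketch below illustrates the setup-cost argument under stated assumptions: expensive_setup is a hypothetical stand-in for costly initialization, the counter only works because everything runs locally in one process, and the partition is simulated with a plain list.

```python
from typing import Iterable, Iterator

setup_calls = 0  # counts how often the hypothetical expensive setup runs


def expensive_setup() -> dict:
    """Hypothetical stand-in for costly initialization (e.g. opening a connection)."""
    global setup_calls
    setup_calls += 1
    return {"offset": 100}


def per_element(row) -> int:
    """Shape used with map(): the setup is repeated for every element."""
    resource = expensive_setup()
    return row[0] + resource["offset"]


def per_partition(rows: Iterable) -> Iterator[int]:
    """Shape used with mapPartitions(): the setup runs once and is reused."""
    resource = expensive_setup()
    for row in rows:
        yield row[0] + resource["offset"]


partition = [(1,), (2,), (3,)]  # simulated single partition

setup_calls = 0
map_style = [per_element(row) for row in partition]
map_setups = setup_calls

setup_calls = 0
partition_style = list(per_partition(iter(partition)))
partition_setups = setup_calls

print(map_style, map_setups)              # [101, 102, 103] 3
print(partition_style, partition_setups)  # [101, 102, 103] 1
```

Both shapes produce the same values. The difference is only in how often the setup runs, which is what matters when the setup is slow.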

In summary, use map() for an independent, one-to-one transformation of each element, and use mapPartitions() when you need to process a partition as a whole, for example to share setup or state across its elements or to return a different number of elements than you received.