Apache Spark : Unit testing avec JUnit en Java

Cet article est le deuxième d’une série sur l’utilisation de Apache Spark en Java, voir Apache Spark : Use cases pour développeurs Java pour la partie précédente.

Nous allons traiter le sujet des tests unitaires pour Spark avec JUnit, de la création d’un Dataset de test, de l’assertion d’un Dataset Spark, et de l’utilisation des @Rule JUnit4 ou Extension JUnit5 pour la création du SparkSession partagé. Nous allons même voir comment faire de l’assertion distribuée avec Spark !

Le code de cet article est tiré d’exemples d’usage de Spark en Java sur notre github et est associé à une conférence que nous donnons sur le sujet.

Tester c’est douter : un exemple de notre code

Reprenons l’exemple de l’article précédent : comment trouver la moyenne des prix par formule pour un assureur donné ?

public class TarifsRun {

    private static final SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .getOrCreate();

    public static void main(String[] args) {
        Dataset<Row> tarifs = spark.read()
                .option("header", true)
                .option("inferSchema", true)
                .csv(PATH_TARIFS_CSV)
                .filter((FilterFunction<Row>) value ->
                        value.<String>getAs("assureur").equals("Mon SUPER assureur"))
                .groupBy("formule")
                .agg(avg("prime").as("average"))
                .orderBy(desc("average"));
        tarifs.show();
    }

}

Il faut d’abord rendre le code testable, en isolant le code à tester. On ne cherche pas à tester le chargement CSV, alors on extrait une méthode s’appelant averagePrime en bon français, qui prend en paramètre l’extraction CSV et qui retourne le Dataset traité.

public static Dataset<Row> averagePrime(Dataset<Row> tarifs) {
     return tarifs
            .filter((FilterFunction<Row>) value ->
                    value.<String>getAs("assureur").equals("Mon SUPER assureur"))
            .groupBy("formule")
            .agg(avg("prime").as("average"))
            .orderBy(desc("average"));
}

La création d’un Dataset de test

Notre code de test, exécuté par JUnit, s’écrit comme un test normal, où on commence par initialiser une session locale Spark.

private static final SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .getOrCreate();

private Dataset<Row> tarifs;

Ensuite, on cherche à créer notre Dataset de test appelé tarifs. La méthode à utiliser pour créer un Dataset de toute pièce est Spark#createDataFrame, qui prend une liste de Row, qu’on pourra créer avec RowFactory. On écrit donc un @Before qui s’occupe de la création :

@Before
public void before() {
    List<Row> rows = Arrays.asList(
            RowFactory.create("F1", 50d, "Mon SUPER assureur"),
            RowFactory.create("F1", 100d, "Mon SUPER assureur"),
            RowFactory.create("F2", 70d, "Mon SUPER assureur"));

    StructField formule = new StructField("formule", StringType, false, empty());
    StructField prime = new StructField("prime", DoubleType, false, empty());
    StructField assureur = new StructField("assureur", StringType, false, empty());
    StructType structType = new StructType(new StructField[]{formule, prime, assureur});

    tarifs = spark.createDataFrame(rows, structType);
}

Il faut remarquer qu’on n’a pas besoin de créer toutes les colonnes du Dataset de base. Celui d’origine contient également un « uid », une « date », etc. mais ces colonnes ne sont pas utilisées dans la méthode averagePrime, on peut donc les ignorer dans notre Dataset.

La méthode RowFactory#create prend un vararg, dont chaque élément correspond à une valeur typée de colonne. Le type et le nom de ces colonnes est donné par un StructField, et le type d’un Dataset est donné par un StructType, qui contient la liste des colonnes.

L’écriture des assertions

Ensuite, il faut appeler la méthode à tester et vérifier le retour, soit un Dataset contenant les primes moyennes, ordonné par prix.

    @Test
    public void should_average_tarif_return_correct_average() {
        Dataset<Row> averagePrime = TarifsRun.averagePrime(tarifs);

        assertEquals(2, averagePrime.count());
        assertEquals("F1", averagePrime.first().getAs("formule"));
        assertEquals(75, (double) averagePrime.first().<Double>getAs("average"));
    }

On remarque que les assertions sont des assertions JUnit normales, avec quelques subtilités sur les types. L’utilisation de Dataset#count et Dataset#first permet de facilement récupérer les valeurs voulues. On aurait pu aussi utiliser Dataset#takeAsList qui retourne les « n » premières lignes sous forme de liste de Row.

Attention cela se complique si vous voulez itérer sur le Dataset pour en tester les valeurs, vous pouvez faire un Dataset#foreach et mettre du code d’assertion JUnit dans la fonction. Mais puisque ce code s’exécute sur le worker, il faut s’assurer que la fonction soit sérialisable. Par exemple, ce code ne compile pas, avec une exception de type java.io.NotSerializableException :

averagePrime.foreach(new ForeachFunction<Row>() {
    @Override
    public void call(Row row) throws Exception {
        assertNotNull(row.getAs("formuleReadable"));
    }
});

Ceci est du au fait que le SAM (Single Abstract Method) de type ForeachFunction est compilé comme une classe anonyme interne à com.lesfurets.spark.examples.TarifsRunTest, qui elle n’est pas sérialisable. Il est possible d’ajouter un implements Serializable sur notre classe de test, mais on a d’autres alternatives, comme utiliser une lambda :

// On teste carrément sur le worker, remarquez que c'est le même code que l'exemple précédent, 
// mais avec une lambda au lieu d'une SAM
averagePrime.foreach((ForeachFunction<Row>) row -> assertNotNull(row.getAs("formuleReadable")));

Dans ce deuxième exemple, le code d’assertion est exécuté sur le worker. Autrement dit, on fait de l’assertion JUnit distribuée ! Plus précisément, c’est de l’assertion parallèle parce qu’on est en session locale (voir « Tester sur un vrai cluster de test »).

Démarrer un SparkSession pour tous les tests

L’exemple précédent démarre Spark (à l’initialisation du SparkSession) à chaque début de test, et si vous en avez beaucoup, cela augmente le temps total de vos tests unitaires inutilement. Le mieux est de démarrer une fois Spark et de réutiliser la même instance de SparkSession.

Avec JUnit4, cela revient à créer une classe qui étend ExternalResource et qui initialise le SparkSession au premier appel à @Before, puis retourne la même instance ensuite. Avec JUnit5, il faut écrire une extension qui fait le même travail. Nous avons déjà fait le travail pour JUnit4, dans la classe com.lesfurets.spark.junit4.rule.SparkTestSession, et pour JUnit5, dans l’annotation com.lesfurets.spark.junit5.extension.SparkTest. Ces classes sont disponibles sur notre github.

Avec l’extension @SparkTest et JUnit5, le code précédent devient :

@SparkTest
public class TarifsRunTest {

    public SparkSession spark;

}

Le SparkSession déjà initialisé (instance partagée par tous les tests) sera automatiquement injecté dans le field « spark ». Cela a aussi comme avantage d’éviter de démarrer trop de workers : le nombre est fixé au nombre de cœurs lors du démarrage de Spark.

Tester sur un vrai cluster de test

On remarque que les tests utilisent tous la session Spark local[*], ce qui veut dire que le driver et les workers sont lancés sur la même JVM. Cela est pratique pour faire quelques tests, mais il y a plusieurs avantages à utiliser un cluster de test :

  • cela permet de lancer des tests plus lourds sur un cluster dédié sans plomber la machine de build : par exemple pour des assertions distribuées
  • cela permet de tester dans un environnement semblable à la production : le driver ne sera pas sur la même JVM que les workers

Ce dernier point est important à tester, puisque dans certains cas, les objets transférés entre les workers ne seront pas sérialisables, et vous ne le saurez qu’au runtime, dans un vrai cluster.

Le plus simple est de monter un petit cluster Spark dédié aux tests, accessibles depuis vos machines de build.

Conclusion

L’API des DataFrame Spark est facilement testable, et on remarque même la possibilité de faire de l’assertion distribuée avec Spark + JUnit et ça, c’est cool.

N’oubliez pas que ce code, ainsi que d’autres exemples, sont disponibles sur notre github.

import com.lesfurets.spark.junit5.extension.SparkTest;
import org.apache.spark.api.java.function.ForeachFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import org.junit.jupiter.api.*;

import java.util.*;

import static org.apache.spark.api.java.JavaSparkContext.fromSparkContext;
import static org.apache.spark.sql.types.DataTypes.*;
import static org.apache.spark.sql.types.Metadata.empty;
import static org.junit.jupiter.api.Assertions.*;

@SparkTest
public class TarifsRunTest {

    public SparkSession spark;

    private Dataset<Row> tarifs;

    @BeforeEach
    public void before() {
        List<Row> rows = Arrays.asList(
                RowFactory.create(1, 50d, "Mon SUPER assureur"),
                RowFactory.create(1, 100d, "Mon SUPER assureur"),
                RowFactory.create(2, 70d, "Mon SUPER assureur"));

        StructField formule = new StructField("formule", IntegerType, false, empty());
        StructField prime = new StructField("prime", DoubleType, false, empty());
        StructField assureur = new StructField("assureur", StringType, false, empty());
        StructType structType = new StructType(new StructField[]{formule, prime, assureur});

        tarifs = spark.createDataFrame(rows, structType);
    }

    @Test
    public void should_average_tarif_return_correct_average() {
        Dataset<Row> averagePrime = TarifsRun.averagePrime(tarifs);

        averagePrime.foreach((ForeachFunction<Row>) row -> assertNotNull(row.getAs("formuleReadable")));

        assertEquals(2, averagePrime.count());
        assertEquals(1, (int) averagePrime.first().<Integer>getAs("formule"));
        assertEquals("tiers bdg", averagePrime.first().getAs("formuleReadable"));
        assertEquals(75, averagePrime.first().<Double>getAs("average"), 0.001);
    }

}