State of PySpark, tips and possible improvements

Written by: Igor Korotach

What is Spark?

Apache Spark

Apache Spark is an open-source distributed general-purpose cluster-computing framework. Spark provides an interface for programming entire clusters with implicit data parallelism and fault tolerance.

PySpark

(And its internals)

What is PySpark?

  • Python interface to Spark Core, which is written in Scala/Java
  • The same technique is used to interface with other languages (R, Julia, etc.)
  • Mature and integrates well with the Python ecosystem

PySpark internals

In the Python driver program, SparkContext uses Py4J to launch a JVM and create a JavaSparkContext. Py4J is only used by the driver for local communication between the Python and Java SparkContext objects. Large data transfers go through a different mechanism: local sockets and temporary files on local disk.
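Bulk data does not go through Py4J. Between the JVM and its Python worker processes, records travel roughly as length-prefixed, pickled batches over a local socket. The sketch below is a simplified illustration of that framing. It uses only the standard library. The names `send_batch`, `recv_batch` and `_recv_exact` are hypothetical and are not PySpark's own functions.

```python
import pickle
import socket
import struct

def send_batch(sock, rows):
    """Pickle a batch of rows and write it with a 4-byte big-endian length prefix."""
    payload = pickle.dumps(rows, protocol=pickle.HIGHEST_PROTOCOL)
    sock.sendall(struct.pack("!i", len(payload)) + payload)

def _recv_exact(sock, n):
    """Read exactly n bytes (recv may return fewer bytes than requested)."""
    chunks = []
    while n > 0:
        chunk = sock.recv(n)
        if not chunk:
            raise EOFError("socket closed in the middle of a frame")
        chunks.append(chunk)
        n -= len(chunk)
    return b"".join(chunks)

def recv_batch(sock):
    """Read one length-prefixed frame and unpickle it back into rows."""
    (length,) = struct.unpack("!i", _recv_exact(sock, 4))
    return pickle.loads(_recv_exact(sock, length))

# A connected pair of local sockets stands in for the JVM <-> Python worker link.
jvm_side, python_side = socket.socketpair()
send_batch(jvm_side, [(1, "a"), (2, "b")])
print(recv_batch(python_side))  # [(1, 'a'), (2, 'b')]
jvm_side.close()
python_side.close()
```

Each batch goes through four steps:

  1. It is pickled.
  2. It is copied into the socket buffer.
  3. It is copied out again.
  4. It is unpickled.

This is where PySpark's copying and serialization costs come from.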

A bit more detail!

What's so bad about that?

Well, multiple things!

  • Copying data is expensive
  • End-to-end serialization is expensive
  • Error messages are... a joke, to say the least
  • Memory consumption/management is hell, e.g. when running on YARN

Here comes the new kid on the block...

Apache Arrow

Apache Arrow is a cross-language development platform for in-memory data. It specifies a standardized language-independent columnar memory format for flat and hierarchical data, organized for efficient analytic operations on modern hardware.
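The core idea of a columnar format is to store each column as one contiguous buffer of fixed-width values, rather than as a list of row objects. Here is a toy illustration of that idea. It uses the standard-library `array` module and is not Arrow's API. Real Arrow also has validity bitmaps, offset buffers for variable-length data, and record batches. `rows_to_columns` is a hypothetical helper.

```python
from array import array

def rows_to_columns(rows, schema):
    """Pivot row tuples into one contiguous, typed buffer per column.

    `schema` is a list of (column_name, typecode) pairs, where typecode is an
    `array` module code such as "q" (signed 64-bit int) or "d" (64-bit float).
    """
    columns = {name: array(code) for name, code in schema}
    for row in rows:
        for (name, _), value in zip(schema, row):
            columns[name].append(value)
    return columns

rows = [(1, 0.5), (2, 1.5), (3, 2.5)]
columns = rows_to_columns(rows, [("id", "q"), ("score", "d")])

# Each column is a single contiguous buffer of fixed-width values...
print(columns["score"].tolist())       # [0.5, 1.5, 2.5]
print(columns["id"].buffer_info()[1])  # 3
# ...which a consumer can read through a memoryview without copying it.
view = memoryview(columns["score"])
print(view.obj is columns["score"])    # True
```

Fixed-width, contiguous columns are what let different runtimes (JVM, Python, C++) share the same bytes.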

Apache Arrow's premise

Solving Spark's copying problem

Simple use case

# Set up the Spark client
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand

warehouse_location = "/wh"
spark = (
    SparkSession.builder
    .appName("demo")
    .config("spark.sql.warehouse.dir", warehouse_location)
    .enableHiveSupport()
    .getOrCreate()
)

# Create a test Spark DataFrame with 2^22 random doubles
df = spark.range(1 << 22).toDF("id").withColumn("id", rand())
df.printSchema()

# Benchmark with the IPython/Jupyter %time magic: first without Arrow...
spark.conf.set("spark.sql.execution.arrow.enabled", "false")
%time pdf = df.toPandas()
# ...then with Arrow (renamed spark.sql.execution.arrow.pyspark.enabled in Spark 3.0)
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
%time pdf = df.toPandas()
pdf.describe()

Results

There is still a copying problem :(

Case 1: PyTorch

(workflow as is)

Step 1: Prepare your worker

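# Do the imports (`spark` and `sc` are assumed to be provided by the notebook)
import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import models, transforms
from torchvision.datasets.folder import default_loader
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType

# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2048")

# Use CUDA if it is requested and available
cuda = True

use_cuda = cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")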

Step 2: Prepare your model

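# Broadcast the pretrained weights once from the driver (`sc` is the SparkContext)
bc_model_state = sc.broadcast(models.resnet50(pretrained=True).state_dict())

def get_model_for_eval():
  """Gets the broadcasted model."""
  # The weights come from the broadcast, so executors need not download them again
  model = models.resnet50()
  model.load_state_dict(bc_model_state.value)
  model.eval()
  return model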

Step 3: Load data in Spark DataFrame

# `files` is the list of image paths to score
files_df = spark.createDataFrame(
  [(path,) for path in files], ["path"]
).repartition(10)
display(files_df.limit(10))  # Databricks notebook helper; use .show() elsewhere

Step 4: Run model inference

class ImageDataset(Dataset):
  """Loads images from a list of file paths, applying an optional transform."""
  def __init__(self, paths, transform=None):
    self.paths = paths
    self.transform = transform
  def __len__(self):
    return len(self.paths)
  def __getitem__(self, index):
    image = default_loader(self.paths[index])
    if self.transform is not None:
      image = self.transform(image)
    return image

(create a custom Dataset)

Step 4: Run model inference

def predict_batch(paths):
  """Scores one batch of image paths and returns one prediction vector per path."""
  # Standard ImageNet preprocessing for ResNet-50
  transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
  ])
  images = ImageDataset(paths, transform=transform)
  loader = torch.utils.data.DataLoader(images, batch_size=500, num_workers=8)
  model = get_model_for_eval()
  model.to(device)
  all_predictions = []
  with torch.no_grad():
    for batch in loader:
      # Move the outputs back to the CPU so they can be returned to Spark
      all_predictions.extend(model(batch.to(device)).cpu().numpy())
  return pd.Series(all_predictions)

(Define the function for model inference)

Step 4: Run model inference

# Wrap predict_batch in a scalar pandas UDF that returns an array of floats
predict_udf = pandas_udf(ArrayType(FloatType()),
                         PandasUDFType.SCALAR)(predict_batch)

# Make predictions (`output_file_path` is the target Parquet location)
predictions_df = files_df.select(col('path'),
                                 predict_udf(col('path')).alias("prediction"))
predictions_df.write.mode("overwrite").parquet(output_file_path)

# Load and check the prediction results.
result_df = spark.read.load(output_file_path)
display(result_df)

(Wrap the function in Pandas UDF and run)
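The `maxRecordsPerBatch` setting from Step 1 controls how often `predict_batch` is called. Spark cuts each partition's column into Arrow batches of at most that many records and calls the scalar pandas UDF once per batch. Each call must return exactly one result per input row.

The pure-Python simulation below illustrates that contract without Spark or pandas. It is a simplified sketch, not Spark's implementation. `apply_scalar_udf` and `fake_predict_batch` are hypothetical names.

```python
from typing import Callable, List, Sequence

def apply_scalar_udf(column: Sequence, func: Callable[[list], list],
                     max_records_per_batch: int) -> List:
    """Simulate applying a scalar pandas UDF to one partition's column.

    The column is cut into batches of at most `max_records_per_batch` rows,
    `func` is called once per batch, and each call must return exactly one
    result per input row.
    """
    results: List = []
    for start in range(0, len(column), max_records_per_batch):
        batch = list(column[start:start + max_records_per_batch])
        out = list(func(batch))
        if len(out) != len(batch):
            raise ValueError(f"expected {len(batch)} results, got {len(out)}")
        results.extend(out)
    return results

# Hypothetical stand-in for predict_batch: records batch sizes, returns path lengths.
seen_batch_sizes = []

def fake_predict_batch(paths):
    seen_batch_sizes.append(len(paths))
    return [len(p) for p in paths]

paths = [f"/images/{i}.jpg" for i in range(5000)]
predictions = apply_scalar_udf(paths, fake_predict_batch, max_records_per_batch=2048)
print(seen_batch_sizes)  # [2048, 2048, 904]
print(len(predictions))  # 5000
```

The DataLoader inside `predict_batch` uses `batch_size=500`. So each 2048-path Arrow batch becomes four GPU mini-batches of 500 plus one of 48.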

PyTorch workflow

PyTorch workflow optimized

Good news!

1) Demand for Java in PyTorch

https://github.com/pytorch/pytorch/issues/6570

2) PyTorch JNI bindings as a Gradle library

https://github.com/pytorch/pytorch/issues/28986

3) This is expected to arrive in PyTorch 1.4!

Even better news!

import java.util.Arrays;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;

Module mod = Module.load("demo-model.pt1");
Tensor data =
    Tensor.fromBlob(
        new int[] {1, 2, 3, 4, 5, 6}, // data
        new long[] {2, 3} // shape
        );
IValue result = mod.forward(IValue.from(data), IValue.from(3.0));
Tensor output = result.toTensor();
System.out.println("shape: " + Arrays.toString(output.shape()));
System.out.println("data: " + Arrays.toString(output.getDataAsFloatArray()));

Case 2: TensorFlow

(possible optimization)

What does it look like?

Load model

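import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

// The file must be a frozen GraphDef; a SavedModel directory should be
// loaded with SavedModelBundle.load(exportDir, "serve") instead
Path modelPath = Paths.get(LoadTensorflowModel.class.getResource("saved_model.pb").toURI());
byte[] graph = Files.readAllBytes(modelPath);

try (Graph g = new Graph()) {
    g.importGraphDef(graph);
    // Open a session using the imported graph
    try (Session sess = new Session(g)) {
        float[][] inputData = {{4, 3, 2, 1}};
        // We have to create a tensor to feed it to the session,
        // unlike in Python where you just pass a NumPy array
        try (Tensor<Float> inputTensor = Tensor.create(inputData, Float.class)) {
            float[][] output = predict(sess, inputTensor);
            for (int i = 0; i < output[0].length; i++) {
                System.out.println(output[0][i]); // should be 41. 51.5 62.
            }
        }
    }
}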

What does it look like?

Run model

private static float[][] predict(Session sess, Tensor<?> inputTensor) {
    // Feed the "input" node and fetch the "not_activated_output" node
    try (Tensor<?> result = sess.runner()
            .feed("input", inputTensor)
            .fetch("not_activated_output").run().get(0)) {
        float[][] outputBuffer = new float[1][3]; // this model outputs 1 x 3 floats
        result.copyTo(outputBuffer);
        return outputBuffer;
    }
}

Are there any problems left?

Are there any optimization problems left?

CPU zero copy is a lie!

Apache Arrow provides an efficient storage and conversion format, but it doesn't solve the data transfer problem. Data is still copied and sent over a socket between the JVM and the Python process. One solution would be to use a shared-memory object store such as the Plasma Store.
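Plasma has its own client API. As a self-contained stand-in, the sketch below uses Python's standard `multiprocessing.shared_memory` module (Python 3.8+) to show the idea:

  • The producer writes a column into a named shared-memory block once.
  • A consumer attaches to the block by name and reads it through a view, instead of receiving a copy over a socket.

This is not Plasma's API. Both sides run in one process here for brevity. `publish` and `attach` are hypothetical names.

```python
from array import array
from multiprocessing import shared_memory

def publish(values):
    """Write a float64 column into a new named shared-memory block."""
    column = array("d", values)
    nbytes = len(column) * column.itemsize
    shm = shared_memory.SharedMemory(create=True, size=nbytes)
    shm.buf[:nbytes] = column.tobytes()
    return shm, len(column)

def attach(name, length):
    """Open an existing block by name and view it as float64 values without copying."""
    shm = shared_memory.SharedMemory(name=name)
    view = shm.buf[:length * 8].cast("d")  # 8 bytes per float64
    return shm, view

writer, n = publish([0.5, 1.5, 2.5])
reader, view = attach(writer.name, n)
print(view.tolist())  # [0.5, 1.5, 2.5]

# Release views before closing, then unlink once every process is done.
view.release()
reader.close()
writer.close()
writer.unlink()
```

The explicit release/close/unlink sequence hints at the extra lifecycle management a shared store brings.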

 

Why isn't this done yet?

Well, a few reasons...

  • The Plasma store only works on Linux and macOS

  • It is an extra component that needs to be started and synchronised, and that can fail

  • It adds another layer of complexity

  • The Plasma store doesn't have an API (or any plans to add one) for all Spark clients (namely R)

CPU <--> GPU transfer dilemma

Are there any solutions?

Well, technically, yes!

But practically, no! :(

Why?

Technically, a zero-copy situation is achieved only if ALL of the following apply:

  1. You are using a platform with an iGPU (integrated GPU; e.g. Jetson)
  2. The platform has APIs to handle manual memory access patterns (e.g. CUDA)
  3. TensorFlow/PyTorch has bindings to these APIs (e.g. https://github.com/tensorflow/tensorflow/issues/31441)
  4. You take on a ton of extra memory management work...

The speed-up is 10%-58%, and memory usage is reduced by 50%. These figures are for PX2 and TX2/AGX on example tasks (Autoware and drone obstacle detection).

What do we do until then?

You could help us achieve this brighter future!

Thanks for your attention. You've been awesome!

Questions?
