daita-technologies / ai-tools

AI-based tools for the DAITA platform.

Home Page: http://app.daita.tech

How to conduct train/validation/test data split

pcaversaccio opened this issue · comments

Generally, the main question an ML engineer must ask is: How should I pick my train, validation, and test set?

Naively, one could split the dataset manually into three chunks (as is currently done in the MVP). The problem with this approach is that humans are biased, and that bias ends up in the three sets. In academia, we learn that we should pick them randomly (for example, using the built-in methods from scikit-learn). With a random split, all three sets are drawn from the same statistical distribution, up to sampling noise. That's what we want, since ML is all about statistics. Deriving the three sets from completely different distributions would yield unwanted results: there is not much value in training a model on pictures of cats if we want to use it to classify flowers.
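As a minimal sketch of such a random split, the snippet below shuffles the samples with a fixed seed and cuts them into three parts using only the Python standard library. The function name `random_split` and the 70/15/15 proportions are illustrative assumptions, not what the MVP does; scikit-learn's `train_test_split` can be applied twice to get a comparable three-way split.

```python
import random


def random_split(samples, val_fraction=0.15, test_fraction=0.15, seed=0):
    """Shuffle samples and cut them into train/validation/test lists.

    The fractions and the seed are illustrative defaults (assumptions).
    """
    if val_fraction < 0 or test_fraction < 0 or val_fraction + test_fraction >= 1:
        raise ValueError("fractions must be non-negative and sum to less than 1")

    shuffled = list(samples)
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(shuffled)

    n_test = round(len(shuffled) * test_fraction)
    n_val = round(len(shuffled) * val_fraction)

    test = shuffled[:n_test]
    val = shuffled[n_test:n_test + n_val]
    train = shuffled[n_test + n_val:]
    return train, val, test


# Example: 100 sample indices standing in for images.
train, val, test = random_split(range(100))
print(len(train), len(val), len(test))
```

Fixing the seed means the same samples land in the same set on every run, which keeps evaluation numbers comparable across experiments.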

However, the underlying assumption of a random split is that the initial dataset already matches the statistical distribution of the problem we want to solve. For a problem such as autonomous driving, that would mean our dataset covers all sorts of cities, weather conditions, vehicles, seasons of the year, special situations, etc. As you might expect, this assumption does not hold for most practical deep learning applications. Whenever we collect data using sensors in an uncontrolled environment, we might not get the desired data distribution. The relevant research area deals with finding and handling domain gaps, distributional shifts, or data drift. The ultimate goal should be to build a model robust enough to handle such domain gaps, that is, a model that performs well on out-of-distribution tasks.

In machine learning, we speak of an out-of-distribution setting whenever a model has to perform well on new input data that comes from a different distribution than the training data. Going back to the autonomous driving example, for a model that has only been trained in sunny Hanoi weather, making predictions in Europe is out of distribution.

Now, how should we split the dataset for such a task? Since we collected the data using different sensors, we might also have additional information about the source of each sample (a sample could be an image, lidar frame, video, etc.).
We can use this source information to split the dataset in the following way:

  • we train on a set of data from cities in list A
  • and evaluate the model on a set of data from cities in list B
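The city-based split above could be sketched as follows. This assumes each sample carries a `city` field in its metadata; the field name, the file names, and every city other than Hanoi are hypothetical and only serve as an illustration.

```python
def split_by_city(samples, train_cities, eval_cities):
    """Split samples by source city so that no city appears in both sets."""
    train_cities = set(train_cities)
    eval_cities = set(eval_cities)

    overlap = train_cities & eval_cities
    if overlap:
        raise ValueError(f"cities in both lists: {sorted(overlap)}")

    train = [s for s in samples if s["city"] in train_cities]
    evaluation = [s for s in samples if s["city"] in eval_cities]
    return train, evaluation


# Hypothetical metadata records; real samples would point to images,
# lidar frames, videos, etc.
samples = [
    {"file": "frame_001.png", "city": "Hanoi"},
    {"file": "frame_002.png", "city": "Hanoi"},
    {"file": "frame_003.png", "city": "Da Nang"},
    {"file": "frame_004.png", "city": "Zurich"},
    {"file": "frame_005.png", "city": "Berlin"},
]

list_a = ["Hanoi", "Da Nang"]   # cities used for training
list_b = ["Zurich", "Berlin"]   # cities held out for evaluation

train_set, eval_set = split_by_city(samples, list_a, list_b)
print(len(train_set), len(eval_set))
```

Because the evaluation cities never appear during training, the evaluation score reflects how the model handles the domain gap rather than how well it memorised the cities it has already seen. Samples from cities in neither list are simply left out of both sets.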

There is a great article from Yandex Research about their new dataset for tackling distributional shifts.