How to Train Practically Any Model from Practically Any Data with TensorFlow
Objectives
This post will guide you from your data (in a CSV file) all the way to a trained TensorFlow model of your choosing.
You’re not going to find any tricks or hacks here. The title of this blog post is so general because the TensorFlow developers have created a great API for importing data and training standard models. If you follow all the suggestions of the official TensorFlow docs, you should come to the same conclusions I do here.
It may be tempting to quickly write a script that works for your current data and current task, but if you take a little extra time and write generalizable code, you will save yourself headaches in the future. The instructions here will help you easily scale to different datasets and different model architectures.
Prerequisites
- A working, recent version of TensorFlow installed.
- Your data in CSV format. The reason I choose CSV data as the starting point is that almost any data can be formatted as a CSV file. Getting your raw data into a CSV file is on you, but once you get there, the rest is smooth sailing :) From CSV data, I show you how to get your data into tfrecords format, which is the preferred TF data format. So, if your data is already in tfrecords, you’re already ahead of the curve!
Install TensorFlow
Just follow the official installation instructions!
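Once it’s installed, here’s a quick sanity check that simply prints whichever version you ended up with:

import tensorflow as tf
print(tf.__version__)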
Get Data in CSV
To ground this post in a concrete example, below is my own labeled data in CSV. Each training data example is represented as a single row in the CSV file, where the first column represents the label (an integer), and all the following columns contain the features for that example (floating point numbers).
The labels in the first column represent categories of speech sounds; for example, label 45 might be the vowel [oh] and label 7 might be the consonant [k]. The features shown in column two onward correspond to amplitudes at different frequency ranges. In a nutshell, this is speech data, where a snippet of audio (features) has been labeled with a language sound (label).
Here’s what four lines of my data CSV file look like (where the delimiter is a single space):
95 21.50192 -2.16182 -1.591426 0.06965831 0.6690025 ... -0.7368361 -1.385849 0.7551874 -0.8878949 -0.4799456
7 22.23698 -1.177924 -1.368747 -0.6289141 0.009075502 ... -0.9235415 -1.74792 0.2629939 -2.119366 -0.539937
45 22.83421 -0.9043457 -1.591426 -0.816999 -0.3035215 ... -0.5301266 -1.456303 -0.1479924 -1.641482 -0.04098308
27 -0.9376022 -0.05841255 0.3308391 -0.7141842 -0.3867566 ... -1.263647 23.4316 -0.0009118451 -1.035212 -1.635385
Pretty simple, right? One training example is one line in the CSV file.
If your data isn’t in this kind of CSV format, you’re going to have to spend a little time to get it there. The most important point is that you need one training example per line, and you should know exactly where each part of the example is located. For my example, I know that the label is the first column, and all the following columns are my features. You also must know how each label/feature is represented. For my case, all the labels are integers, and all the features are floating point numbers. You might have text-based labels or features (e.g. words from text), or you could have categorical features (e.g. you have a feature for "color" that you’ve coded as integers 1 through 5).
Whatever the case, you need to know exactly:
- Where your data is (i.e. which column)
- How your data is coded (e.g. float vs. integer vs. text)
- What your data means (e.g. the integer entry 43 in column 5 corresponds to the color blue)
The last point is very important, because you might have integers whose numerical distance doesn’t correspond to anything meaningful (e.g. the distance between 3 and 7 means nothing if 3 is "orange" and 7 is "magenta"). On the other hand, you might have integers where the distance between them is very important (e.g. a test score of 99 is much better than a score of 59). The distance between test scores is meaningful, but the distance between colors is not.
In what follows, you have to decide how to represent your values, and whether or not their distances matter.
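As a quick preview of how this distinction plays out in TensorFlow (the feature_column API is covered in detail later in this post; the column names "color" and "score" here are made up for illustration):

import tensorflow as tf

# Categories whose numerical distance is meaningless: treat them as categorical
# and expand to a one-hot (indicator) representation.
color = tf.feature_column.categorical_column_with_identity("color", num_buckets=10)
color_onehot = tf.feature_column.indicator_column(color)

# Values whose numerical distance is meaningful: keep them as plain numbers.
score = tf.feature_column.numeric_column("score")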
Convert CSV to TFRecords
TFRecords is the preferred file format for TensorFlow. These tfrecords files take up a lot of space on disk, but they can be easily sharded and processed across machines, and the entire TensorFlow pipeline is optimized with tfrecords in mind.
To work with tfrecords data, you first have to format your CSV data using TensorFlow itself. We have to read in the CSV file one example at a time, format each one as a tf.train.Example, and then write that example to a file on disk. Each tf.train.Example stores information about that particular example via so-called features, where these features can be anything (including the target label!). You store each Example’s features as items in a dictionary, where the keys should be descriptive. You can see in the following example that I have chosen the keys "label" and "feats" to make sure I won’t mix them up.
Below is an example Python script to read in a .csv data file and save it to a .tfrecords file. You can find the original version of the following csv-to-records.py here. There are faster ways to do this (i.e. via parallelization), but I want to give you working code that is as readable as possible.
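Here is a minimal sketch of such a script. It assumes space-delimited rows with the integer label in the first column and float features in the remaining columns, as in my data above; the file paths are placeholders you should change.

import tensorflow as tf

CSV_PATH = "my_data.csv"              # placeholder input path
TFRECORDS_PATH = "my_data.tfrecords"  # placeholder output path

with tf.python_io.TFRecordWriter(TFRECORDS_PATH) as writer:
    with open(CSV_PATH) as csv_file:
        for line in csv_file:
            fields = line.strip().split(" ")
            label = int(fields[0])
            feats = [float(value) for value in fields[1:]]

            # One tf.train.Example per CSV line, with the keys "label" and "feats".
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                "feats": tf.train.Feature(float_list=tf.train.FloatList(value=feats)),
            }))
            writer.write(example.SerializeToString())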
There are plenty of resources on tfrecords out there; check out the official docs on reading data, Python-IO, and importing data.
Pat on the Back
If you’ve gotten to this point, you have successfully converted your data and saved it in TFRecords format. Take a pause and pat yourself on the back, because you’ve accomplished the most time-consuming and boring part of machine learning: data formatting.
Now that you have your data in a format TensorFlow likes, we can import that data and train some models. Before we jump straight into training code, you’ll want a little background on TensorFlow’s awesome APIs for working with data and models: tf.data and tf.estimator.
Datasets and Estimators
The official TensorFlow docs push hard for you to use their Dataset and Estimator APIs. In general, if the docs explicitly tell you there is a preferred way to do something, you should do it, because the newest features are sure to work with the recommended approach but may not work with others.
Dataset API
tf.data.Dataset
The Dataset class allows you to easily import, shuffle, transform, and batch your data. The Dataset API makes any pre-processing operation on your data just another part of the pipeline, and it’s optimized for large, distributed datasets. Your entire pre-processing pipeline can be as simple as this:
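Here is a minimal sketch of that pipeline; the filename and batch size are placeholders, and parser is the parsing function discussed next.

import tensorflow as tf

dataset = tf.data.TFRecordDataset(["my_data.tfrecords"])  # point TF at your records on disk
dataset = dataset.map(parser)                # parse each serialized Example (see next section)
dataset = dataset.shuffle(buffer_size=1000)  # optional
dataset = dataset.batch(32)                  # optional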
In the above definition of dataset, you can see there’s a line where you point TensorFlow to your data on disk and read the data via tf.data.TFRecordDataset. The .shuffle() and .batch() functions are optional, but you will need the .map() function.
The .map() function provides the machinery for parsing your data into meaningful pieces like “labels” and “features”. However, .map() is a super general function, and it doesn’t know anything about your data, so we have to pass it a special parsing function which .map() then applies to the data. This parser function is probably the main thing you have to create for your own dataset, and it should exactly mirror the way you saved your data to TFRecords with the tf.train.Example object (in the data formatting section above). Read more about parser functions in the official docs.
The above is one of the simplest ways to load, shuffle, and batch your data, but it is not the fastest way. For tips on speeding this stage up, take a look here and here.
Here’s an example of such a parser function:
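This is a minimal sketch, assuming the "label"/"feats" keys used when writing the TFRecords above and my 377-dimensional float features (change the length for your data):

import tensorflow as tf

def parser(serialized_example):
    parsed = tf.parse_single_example(
        serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.int64),
            "feats": tf.FixedLenFeature([377], tf.float32),
        })
    # Return a (features, label) pair; the dict key "feats" must match the
    # feature_column defined for the model later on.
    features = {"feats": parsed["feats"]}
    return features, parsed["label"]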
To get into the details of this function and how you can define one for your data, take a look at the official parse function docs. Remember that if you have labeled training data, the features definition above includes the data features (feats) as well as the labels (label). If you’re doing something like k-means clustering (where labels aren’t used), you won’t return a label.
Estimator API
tf.estimator.Estimator
The Estimator class gives you an API for interacting with your model. Here’s a good overview from the official docs. It’s like a wrapper for a model which allows you to train, evaluate, and export the model, as well as make inferences on new data. Usually you won’t be interacting directly with the base class tf.estimator.Estimator, but rather with the Estimator classes which directly inherit from it, such as the DNNClassifier class. There is a whole set of pre-defined, easy-to-use Estimators which you can start working with out of the box, such as LinearRegressor or BoostedTreesClassifier.
You can instantiate an Estimator object with minimal, readable code. If you decide to use the pre-existing Estimators from TensorFlow (i.e. “pre-canned” models), you can get started without digging any deeper than the __init__() function! I’ve defined a 4-layer Deep Neural Network which accepts my input data (377-dimensional feature vectors) and predicts one of my 96 classes as such:
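This is a sketch of that definition; the hidden-layer sizes and model_dir are illustrative choices, not requirements.

import tensorflow as tf

feature_columns = [tf.feature_column.numeric_column("feats", shape=[377])]

estimator = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,      # input layer: 377-dimensional float vectors
    hidden_units=[256, 256, 256, 256],    # four hidden layers
    n_classes=96,                         # output layer: one of 96 speech-sound labels
    model_dir="my_model_dir")             # placeholder directory for checkpoints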
We’ve just defined a new DNN classifier with an input layer (feature_columns), four hidden layers (hidden_units), and an output layer (n_classes). Pretty easy, yeah?
You will probably agree that each of these three arguments is very clear except for maybe the feature_columns argument. You can think of “feature_columns” as being identical to “input_layer”. However, feature_columns allows you to do a whole lot of pre-processing that a traditional input layer would never allow. The official documentation on feature_columns is really good, and you should take a look. In a nutshell, think of these feature_columns as a set of instructions for how to squeeze your raw data into the right shape for a neural net (or whatever model you’re training). Neural nets cannot take as input words, integers, or anything else that isn’t a floating point number.
The feature_columns API not only helps you get your data into floats, but it also helps you find floats that actually make sense for the task at hand. You can easily encode words or categories as one-hot vectors, but one-hot vectors are not practical if you have a billion different words in your data. Instead of using one-hot vector feature_columns, you can use the feature_column type embedding_column to find a lower-dimensional representation of your data. In the example above, I use feature_column.numeric_column because my input data is already encoded as floating point numbers.
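For contrast, here is a hedged sketch of both column types; the "word" column and vocabulary size are invented for illustration.

import tensorflow as tf

# My case: the features are already floats, so they pass straight through.
feats_column = tf.feature_column.numeric_column("feats", shape=[377])

# A categorical case: hash words into buckets and learn a small dense
# embedding instead of a huge one-hot vector.
word_ids = tf.feature_column.categorical_column_with_hash_bucket(
    "word", hash_bucket_size=100000)
word_embedding = tf.feature_column.embedding_column(word_ids, dimension=16)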
Putting it All Together
Below is an example of the minimal code you need for importing a tfrecords file, training a model, and making predictions on new data.
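The following is a minimal end-to-end sketch under the same assumptions as above ("label"/"feats" keys, 377 features, 96 classes); the file paths, batch size, layer sizes, and step counts are all placeholders.

import tensorflow as tf

def parser_fn(serialized_example):
    # Mirror the keys used when the TFRecords were written.
    parsed = tf.parse_single_example(
        serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.int64),
            "feats": tf.FixedLenFeature([377], tf.float32),
        })
    return {"feats": parsed["feats"]}, parsed["label"]

def input_fn(tfrecords_path, shuffle=True):
    # Define the dataset and its iterator; return one batch of (features, labels).
    dataset = tf.data.TFRecordDataset([tfrecords_path])
    dataset = dataset.map(parser_fn)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=1000).repeat()  # repeat for multiple training epochs
    dataset = dataset.batch(32)
    return dataset.make_one_shot_iterator().get_next()

estimator = tf.estimator.DNNClassifier(
    feature_columns=[tf.feature_column.numeric_column("feats", shape=[377])],
    hidden_units=[256, 256, 256, 256],
    n_classes=96,
    model_dir="my_model_dir")

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn("train.tfrecords"), max_steps=10000)
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn("eval.tfrecords", shuffle=False))

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

predictions = estimator.predict(
    input_fn=lambda: input_fn("new_data.tfrecords", shuffle=False))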
parser_fn
The parser function will be the most data-specific part of your code. Learn about how to make a good one here.
input_fn
This is an Estimator input function. It defines things like datasets and batches, and can perform operations such as shuffling. Both the dataset and the dataset iterator are defined here. Read more about how to make a good input_fn in the official docs.
Estimator
To get started fast, just choose an Estimator from the available pre-made Estimators. For more detail on how to use pre-made Estimators in general, check out the official docs.
If you want a custom architecture which is not pre-made, you can build your own Estimator.
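If you do go the custom route, a rough sketch of a model_fn-based Estimator under the same assumptions as above might look like this; it is illustrative, not a drop-in replacement for DNNClassifier.

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Build the network from the same feature_columns used elsewhere.
    net = tf.feature_column.input_layer(features, params["feature_columns"])
    for units in params["hidden_units"]:
        net = tf.layers.dense(net, units, activation=tf.nn.relu)
    logits = tf.layers.dense(net, params["n_classes"], activation=None)

    predicted_classes = tf.argmax(logits, axis=1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode, predictions={"class_ids": predicted_classes})

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode, loss=loss)  # EVAL mode

custom_estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    params={
        "feature_columns": [tf.feature_column.numeric_column("feats", shape=[377])],
        "hidden_units": [256, 256],
        "n_classes": 96,
    })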
Train & Eval Specs
Defining the training and evaluation routine for your model is easy with TrainSpec and EvalSpec. These two classes let you combine your model with your data, along with instructions for how the training and evaluation should run.
After you’ve defined the specs, you feed them to the specialized function tf.estimator.train_and_evaluate, which nicely handles all the heavy lifting. The Google Cloud folks wrote a very nice blog post on how to get the best use out of this function as well as your specs.
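As a rough sketch, the specs for the estimator above might look like the following; max_steps, steps, and throttle_secs are illustrative values, and input_fn is the input function from the minimal example above.

train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn("train.tfrecords"),
    max_steps=50000)                     # stop training after this many steps

eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn("eval.tfrecords", shuffle=False),
    steps=100,                           # number of batches to evaluate on
    throttle_secs=600)                   # wait at least this long between evaluations

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)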
Inference
Finally, to make predictions on new data, just use the .predict() method, which is available to all Estimators.
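For example, with the DNNClassifier above, a sketch of iterating over its predictions looks like this (the "class_ids" key is what DNNClassifier emits; other Estimators may use different keys):

predictions = estimator.predict(
    input_fn=lambda: input_fn("new_data.tfrecords", shuffle=False))

for pred in predictions:
    print(pred["class_ids"])  # the predicted label for one example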
Conclusions
I hope you’ve found this post helpful!
Feel free to leave questions and comments below!