When using a CNN, we often have to provide the dimension of the images ahead of time. Note that I used the singular “dimension”, not the plural “dimensions”. That’s because most of the time, the model requires all input tensors to have the same dimension.

However, our images do not all have the same dimension, and that is a bit of a problem. It turns out that Keras does allow images of variable dimensions as input, as long as you use a global pooling layer, rather than flattening, to reduce the “features” generated by the CNN layers to a fixed-length vector. This is particularly useful to us.

We work a lot on digital document processing, and these scanned documents invariably have different dimensions. For example, some are in landscape mode while others are in portrait mode. Another issue is the variety and diversity of these documents: when processing mortgage documents, for example, I found that bank statements are often letter-size while pay stubs are much smaller. Historically, we needed to resize all these images to the same dimension to feed them into a model. But that comes at a very high cost, because resizing distorts the images and adds noise.

So, in essence, we need the model to be able to process images of different dimensions. There are plenty of examples online showing how to set up such a model, so I won’t waste your time on that topic. However, a new problem emerges quickly: Keras requires every image within a batch to have exactly the same dimension. So most examples I saw simply use a batch size of one.

Well, I have a Titan RTX with 24 GB of GPU RAM. A batch size of one just cannot justify the price tag of this GPU. So let’s build a data generator that produces batches in which every image has the same dimension.

The idea is very simple: we group the images together and make sure each batch has the same dimension. Let’s go through a concrete example.

We are building a model to detect the rotation of scanned documents. To train the model, I have about 400,000 scanned US mortgage forms from Fannie Mae and Freddie Mac. They come with 352 different dimensions:

'826x900', '869x716', '678x842', '900x900', '900x868', '900x607',
'894x751', '845x900', '900x798', '703x860', '886x740', '900x771',
'900x775', '816x900', '890x890', '900x844', '893x900', '900x820',
'652x822', '740x886', '897x900', '900x780', '900x832', '900x776',
'900x789', '763x900', '855x900', '757x900', '900x879', '864x900'..........

Now, let’s talk about the data generator. In TensorFlow/Keras terms, it is called a Sequence (keras.utils.Sequence). To create a data generator, you create a new class that inherits from it.

So let’s talk about this class. There are two methods you MUST implement:

  1. __len__: returns the number of batches (not rows)
  2. __getitem__: returns a single training batch. Its parameter, idx, is the index of a batch. The Keras model seems to call __getitem__ in random order rather than incrementally.
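For reference, a bare-bones Sequence subclass might look like the sketch below. It assumes TensorFlow 2.x (`from tensorflow import keras`); the class name `InMemoryToySequence` is hypothetical, and it slices in-memory arrays purely to show the two required methods, whereas the generator in this post reads images from disk.

```python
import numpy as np
from tensorflow import keras


class InMemoryToySequence(keras.utils.Sequence):
    """Hypothetical toy Sequence that serves slices of in-memory arrays."""

    def __init__(self, x, y, batch_size, **kwargs):
        # forwarding kwargs lets newer Keras versions handle options such as workers
        super().__init__(**kwargs)
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # number of batches, not rows; the last batch may be partial
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # idx is a batch index in the range [0, len(self))
        start = idx * self.batch_size
        end = start + self.batch_size
        return self.x[start:end], self.y[start:end]
```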

In our case, we have 400K images on disk, and it makes no sense to load them all into memory at once. Instead, we keep them on disk and load them on demand (in the past, I have also cached a portion of the images in memory to save I/O time). So rather than just pointing the class at an image folder, we also provide a metadata map that describes each image. The metadata map is effectively a Pandas data frame (you could also use JSON or another suitable format). Every time the model requests a batch, we look up the metadata map, prepare the batch’s Xs and Ys, and return them as NumPy arrays.

Because the Pandas data frame is the centerpiece, we need to do some prep work. Initially, the frame is very simple and has two columns:

  1. filename
  2. rotation (Y)

To better group images together by resolution, we add three more columns:

  1. width
  2. height
  3. wxh in string format

The width and height can be easily acquired by using the following code:

from PIL import Image

def getImageDimension(path):
    # Image.open reads only the header here; pixel data is loaded lazily
    with Image.open(path) as im:
        width, height = im.size
    return width, height

Note that this is fast, because PIL’s Image.open is lazy: it reads the file header to get the size but does not load the pixel data into memory until it is needed.

wxh is simply a formatted string “{width}x{height}”. Why do we need this? There are a few reasons:

  1. It is the key for identifying and sorting images by resolution
  2. It helps the preprocessing function identify the variety of dimensions.
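Filling in the three extra columns could be sketched as follows. The helper name `add_dimension_columns` is hypothetical, and it takes any callable that returns `(width, height)` for a filename (for example, getImageDimension applied to the file's full path, which is an assumption about how the folder and filename are joined).

```python
import pandas as pd


def add_dimension_columns(data_map, get_dimension):
    """Return a copy of data_map with width, height and "{width}x{height}" columns.

    get_dimension is any callable mapping a filename to (width, height).
    """
    dims = [get_dimension(name) for name in data_map["filename"]]
    data_map = data_map.copy()
    data_map["width"] = [w for w, _ in dims]
    data_map["height"] = [h for _, h in dims]
    # the wxh string is the grouping/sorting key for resolutions
    data_map["wxh"] = data_map["width"].astype(str) + "x" + data_map["height"].astype(str)
    return data_map
```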

Because Pandas allows us to store information alongside each row, we can store the batch index on each row. Then, when __getitem__ is called, we only need to select the rows with that particular batch index. The key now is calculating the batch index for each row. Well, that’s actually fairly easy.

Let’s imagine that we have the following rows

file 1, 300x400
file 2, 300x400
file 3, 300x400
file 4, 300x400
file 5, 300x400
file 6, 310x300
file 7, 320x410
file 8, 320x410

and we have a batch size of 4

So files 1–4 should get batch index 0, file 5 batch index 1, file 6 batch index 2, and files 7–8 batch index 3 (the index is zero-based). As for the batch size: even though you pass it to the data generator, the Keras model couldn’t care less about it, as long as each batch fits in memory. In our case, we make it a variable so we can test the limits of our GPU.

Let’s take a look

Several things happen when the class is initialized:

  1. Save all parameters, which are useful later
  2. Load the Pandas data frame from a pickle file
  3. Calculate the batch index for each row. We don’t need to do this if the data frame already contains the batch indices. In fact, the calculation is quite computationally intensive, so you should only do it once unless you need to change the batch size.

The calculation is done within preprocess_metadata_map(). The algorithm is very simple:

  1. Get all distinct resolutions among the images
  2. For each resolution, get all the rows with that resolution
  3. Assign the current batch index to each row until the batch is full; when it is full, increment the batch index
  4. Increment the batch index when moving on to the next resolution
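The four steps above can be sketched in Pandas as follows. This is not the author's preprocess_metadata_map(); the name `assign_batch_indices` is hypothetical, and it assumes the data frame has a `wxh` column and a unique index.

```python
import pandas as pd


def assign_batch_indices(data_map, batch_size):
    """Return a copy of data_map with a zero-based batch_idx column.

    Rows in one batch always share the same "wxh" resolution, and no batch
    holds more than batch_size rows. Assumes data_map has a unique index.
    """
    data_map = data_map.copy()
    data_map["batch_idx"] = -1
    batch_idx = 0
    for wxh in data_map["wxh"].unique():                # 1. every distinct resolution
        rows = data_map.index[data_map["wxh"] == wxh]   # 2. all rows with that resolution
        for position, row in enumerate(rows):
            if position > 0 and position % batch_size == 0:
                batch_idx += 1                          # 3. current batch is full
            data_map.loc[row, "batch_idx"] = batch_idx
        batch_idx += 1                                  # 4. next resolution starts a new batch
    return data_map
```

With the eight files and a batch size of 4 from the earlier example, this yields batch indices 0, 0, 0, 0, 1, 2, 3, 3, so there are 4 batches. Row-by-row `.loc` assignment is slow on 400K rows; the same idea could be vectorized with `groupby("wxh").cumcount() // batch_size`.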

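To verify that the calculation is correct:

idxs = generator.data_map.batch_idx.unique()
for idx in idxs:
    group = generator.data_map.loc[generator.data_map['batch_idx'] == idx]
    assert len(group) <= generator.batch_size  # no batch exceeds the batch size
    assert group['wxh'].nunique() == 1         # every batch has a single resolution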

The total number of batches is simply the max value of batch_idx plus one (because the index is zero-based).

So the __len__ function simply returns a pre-calculated number

def __len__(self):
    return self.total_batches

The pre-calculation makes things a lot easier for __getitem__. At runtime, we only need to grab the rows by batch_idx, then load images into NumPy arrays and return a tuple.

The logic is quite simple:

  1. load each image listed in the metadata map for the batch
  2. add each image to a NumPy array
  3. add each Y (rotation in our case) to a NumPy array
  4. return both NumPy arrays

When loading images, I used grayscale and “bilinear” interpolation. I also rescaled the pixel values so the data are normalized.

Note that in __getitem__, the NumPy arrays are sized by the number of rows actually in that batch, which is not necessarily the generator’s maximum batch size. The reason is that some resolutions have fewer images than the batch size, or leave a remainder. Because the Keras model does not fix the batch size unless you specify one, this is not an issue.
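The batch-building logic could be sketched like this. The helper names `load_grayscale` and `get_batch` are hypothetical, the sketch uses PIL rather than whatever loader the original code uses, and it assumes the metadata map has filename, width, height, rotation and batch_idx columns. Inside __getitem__ you might call something like `get_batch(self.data_map, idx, self.image_folder)`, where `image_folder` is also an assumed attribute name. Because nothing is resized here, no interpolation setting is involved.

```python
import os

import numpy as np
from PIL import Image


def load_grayscale(path):
    """Load one image as float32 in [0, 1] with shape (height, width)."""
    with Image.open(path) as im:
        pixels = np.asarray(im.convert("L"), dtype=np.float32)
    return pixels / 255.0  # rescale 8-bit gray levels to [0, 1]


def get_batch(data_map, idx, image_folder, load=load_grayscale):
    """Build (X, Y) for batch idx from a metadata map with a batch_idx column."""
    rows = data_map[data_map["batch_idx"] == idx]
    # all rows in a batch share one resolution, so read it from the first row
    height = int(rows["height"].iloc[0])
    width = int(rows["width"].iloc[0])
    # size X by the rows actually present, not by the maximum batch size
    x = np.empty((len(rows), height, width, 1), dtype=np.float32)
    for i, filename in enumerate(rows["filename"]):
        x[i, :, :, 0] = load(os.path.join(image_folder, filename))
    y = rows["rotation"].to_numpy()
    return x, y
```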

We create two data generators, one for training and another for testing:

training_generator = TrainingDataGenerator("labels_training.pkl",
                     imageFolder,
                     batch_size=3,
                     shuffle_within_batch=True,
                     recalculate_batch_idx=True)
test_generator = TrainingDataGenerator("labels_testing.pkl",
                     imageFolder,
                     batch_size=3,
                     shuffle_within_batch=True,
                     recalculate_batch_idx=True)

Now we build the model to feed them into.

The first layer is an Input layer that specifies no input shape except the number of channels; since we use grayscale, that is 1:

keras.layers.Input(shape=(None, None, 1))  # (height, width, channels); the batch dimension is implicit
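To show why a global pooling layer makes variable-size inputs work, here is a small sketch of a model built on such an Input layer. It assumes TensorFlow 2.x; the function name `build_rotation_model`, the layer sizes, and `num_classes=4` (for example 0, 90, 180 and 270 degrees) are hypothetical, not the author's architecture.

```python
from tensorflow import keras


def build_rotation_model(num_classes=4):
    """Hypothetical small CNN that accepts grayscale images of any height/width."""
    inputs = keras.layers.Input(shape=(None, None, 1))  # height and width left open
    x = keras.layers.Conv2D(16, 3, padding="same", activation="relu")(inputs)
    x = keras.layers.MaxPooling2D(2)(x)
    x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
    # global pooling collapses any spatial size to a fixed-length vector,
    # so the Dense layer always sees 32 features regardless of image size
    x = keras.layers.GlobalAveragePooling2D()(x)
    outputs = keras.layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
```

With integer rotation labels, compiling with `loss="sparse_categorical_crossentropy"` would be one natural choice.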

Now, we feed the generators to the model

model.fit(training_generator,
          epochs=epoch,
          validation_data=test_generator,
          class_weight=class_weight,
          callbacks=callbacks)
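The fit call passes a class_weight dictionary whose construction is not shown. One common option, sketched here as an assumption rather than the author's method, is the "balanced" heuristic n_samples / (n_classes * count), the same formula scikit-learn uses for class_weight="balanced". It assumes the rotation labels are already integer class indices, since Keras expects class_weight to map class indices to weights; `balanced_class_weight` is a hypothetical name.

```python
import pandas as pd


def balanced_class_weight(labels):
    """Map each class index to n_samples / (n_classes * count of that class)."""
    counts = pd.Series(labels).value_counts()
    n_samples = int(counts.sum())
    n_classes = len(counts)
    return {int(cls): n_samples / (n_classes * int(cnt)) for cls, cnt in counts.items()}
```

For example, `class_weight = balanced_class_weight(training_generator.data_map["rotation"])` would give rarer rotations proportionally larger weights.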

And you are on your way. Good luck learning deep!