The Recurrent Inference Machine

For Accelerating MRI Reconstruction

By Kai Lønning, Patrick Putzky, Matthan A. Caan, Max Welling

This project is funded by the Canadian Institute for Advanced Research and aims to use the Recurrent Inference Machine (RIM) [1] to solve the inverse problem of accelerated MRI reconstruction.

The middle image shows a full-brain RIM reconstruction, starting from the 4× under-sampled corruption on the left and attempting to recover the target on the right.

MR-signals are sampled in K-space, the frequency space of the image signal, and transformed back into image space through the inverse Fourier transform. Due to physiological constraints on the speed at which samples can be acquired in K-space, accelerating the MR-imaging process comes down to taking fewer samples than necessary, then using an algorithm to retrieve the true image signal. The corruption process of accelerated MRI is known, but the inverse process of reconstructing the true image is not; hence we are dealing with an inverse problem.
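Concretely, the corruption process amounts to a binary mask applied to the Fourier transform of the image. The minimal NumPy sketch below (the function name and the centred-mask convention are our own, for illustration only) simulates this forward model and the naive zero-filled reconstruction a solver would start from:

    import numpy as np

    def undersample(image, mask):
        """Simulate the accelerated-MRI corruption process: transform the
        image to K-space, discard the frequencies excluded by the mask,
        and return both the measurements and the zero-filled inverse
        (the 'corrupted' image a reconstruction algorithm starts from)."""
        kspace = np.fft.fftshift(np.fft.fft2(image))    # full K-space, zero frequency centred
        measurements = kspace * mask                    # keep only the sampled frequencies
        corrupted = np.fft.ifft2(np.fft.ifftshift(measurements))  # zero-filled reconstruction
        return measurements, corrupted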

The RIM unrolled over three time-steps. The current best estimate $\mathbf{x}_{t-1}$ is fed to the RIM, along with the current log-likelihood gradient $\nabla_{\mathbf{y}\vert\mathbf{x}_{t-1}}$, and a new update $\Delta\mathbf{x}_{t-1}$ is produced to create the new best estimate $\mathbf{x}_t = \mathbf{x}_{t-1} + \Delta\mathbf{x}_{t-1}$. The RIM also maintains and updates an internal state $\mathbf{s}$, following the Recurrent Neural Network paradigm.

The RIM is a deep learning model designed as a general inverse problem solver. It takes the corrupted signal and the gradient of its log-likelihood as input, and uses these to generate an incremental update to the input, moving it towards an estimate of the true signal. The new estimate and its log-likelihood gradient are then fed into the RIM again, from which another update is generated. This process repeats over a series of recurrent time-steps until a sufficiently good estimate of the true signal has been reached.

Each time-step in the Recurrent Inference Machine produces a new estimate, shown here on the left, from the 3× accelerated corruption to the 10th and final reconstruction. The target is in the middle, while the error (not to scale) is shown on the right.

Following the Recurrent Neural Network paradigm, the RIM also maintains a hidden state, which is updated as inference progresses through the time-steps. The hidden state allows the RIM to steer the reconstruction process based on past states, even though the network reuses the same weights at every time-step. In effect, the network acquires greater depth while retaining a small number of parameters, with which it makes iterative improvements to the signal estimate.
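A minimal PyTorch sketch of this unrolled loop follows. The names (rim_inference, cell, update_net) are placeholders of our own, and the real network uses convolutional recurrent units, so this illustrates the scheme rather than the actual implementation. Under a Gaussian noise model, the log-likelihood gradient for the sub-sampling forward model above is $F^H M^H(\mathbf{y} - MF\mathbf{x})$, i.e. the back-projected K-space residual.

    import torch

    def rim_inference(x0, s0, grad_log_likelihood, cell, update_net, n_steps=10):
        """Unrolled RIM inference. The same weights in `cell` (a recurrent
        unit) and `update_net` (maps the hidden state to an update) are
        reused at every time-step; only the hidden state `s` and the
        estimate `x` change as inference progresses."""
        x, s = x0, s0
        estimates = []
        for _ in range(n_steps):
            grad = grad_log_likelihood(x)      # gradient of log p(y | x) at the current estimate
            inp = torch.cat([x, grad], dim=1)  # estimate and gradient stacked as channels
            s = cell(inp, s)                   # update the internal state
            x = x + update_net(s)              # x_t = x_{t-1} + delta x_{t-1}
            estimates.append(x)
        return estimates                       # every time-step's estimate, for the loss below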


An example of a datapoint used during training. The RIM must recover the target image to the right, using the corrupted image to the left as a starting point.

Examples of sub-sampling patterns used during evaluation, here showing acceleration factors 2, 4 and 6.

Due to memory limitations, the RIM is trained on small image patches of size 30×30, which are stacked together as separate training points in mini-batches. The patches are made from full MRI brain images that have been cropped in image space and then projected into K-space, where the sub-sampling is done. For each mini-batch, the acceleration factor is randomly drawn between 2 and 6. The sub-sampling pattern, i.e. the configuration of which frequencies to discard in K-space, is also randomly selected according to a Gaussian kernel, with more samples drawn from the low frequencies that determine the general shape of the object. The lowest 2% of frequencies are always sampled regardless of the sub-sampling pattern, as it is crucial for the RIM to receive the average image intensity. Perhaps surprisingly, the RIM is able to generalize from small image patches to the full-sized MR-images used during evaluation.
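To make the sampling scheme concrete, the sketch below generates one such pattern. The Gaussian width and the choice of sampling whole K-space columns are assumptions for illustration; only the always-sampled 2% centre and the 2–6 acceleration range come from the setup described above.

    import numpy as np

    def gaussian_mask(n_cols, acceleration, center_fraction=0.02, rng=None):
        """Variable-density sub-sampling pattern over K-space columns.
        Sampling probabilities follow a Gaussian centred on the low
        frequencies, and the lowest `center_fraction` of frequencies is
        always kept so the average image intensity is observed."""
        rng = rng or np.random.default_rng()
        freqs = np.arange(n_cols) - n_cols // 2                # centred frequency index
        probs = np.exp(-0.5 * (freqs / (0.15 * n_cols)) ** 2)  # Gaussian kernel (width assumed)
        mask = np.zeros(n_cols, dtype=bool)
        n_center = max(1, round(center_fraction * n_cols))
        lo = n_cols // 2 - n_center // 2
        mask[lo:lo + n_center] = True                          # always-sampled low frequencies
        n_target = n_cols // acceleration                      # total samples at this acceleration
        probs[mask] = 0.0
        probs /= probs.sum()
        extra = rng.choice(n_cols, size=max(0, n_target - n_center), replace=False, p=probs)
        mask[extra] = True
        return mask

    # Per mini-batch, the acceleration factor is drawn at random (2 to 6 inclusive):
    # mask = gaussian_mask(256, acceleration=np.random.default_rng().integers(2, 7))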

The RIM was trained using the Mean Square Error (MSE) averaged over all time-steps as a loss function. The resulting curves can be seen in the image below, showing the loss on the evaluation set as a function of training iterations for acceleration factors 4 and 6.
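Averaging the loss over time-steps, rather than penalizing only the final output, supervises every intermediate reconstruction. Continuing the sketch above (with the hypothetical `estimates` list returned by `rim_inference`):

    import torch
    import torch.nn.functional as F

    def rim_loss(estimates, target):
        """MSE between each time-step's estimate and the target, averaged
        over all time-steps, so that early estimates are also supervised."""
        return torch.stack([F.mse_loss(x, target) for x in estimates]).mean()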

Loss curves on the evaluation set for acceleration factors 4 and 6. Real- and complex-valued networks shown in purple and cyan.

Due to the complex-valued nature of MR-data, it was proposed that the RIM should be implemented using Wirtinger-calculus, a way to optimize non-holomorphic (i.e. not complex-differentiable) functions of complex variables. As such, the RIM has been implemented for both real-calculus and Wirtinger-calculus, with corresponding loss curves shown above in purple and cyan, respectively. As illustrated, the benefits of using Wirtinger-calculus are thus far marginal, while network inference is slower due to complex algebra requiring more operations.
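For reference, for a function $f$ of a complex variable $z = x + iy$, the Wirtinger derivatives are defined as $\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\frac{\partial f}{\partial y}\right)$ and $\frac{\partial f}{\partial \bar{z}} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\frac{\partial f}{\partial y}\right)$. Both exist even when $f$ is not holomorphic, and for a real-valued loss the steepest-descent direction is given by the conjugate derivative $\partial f / \partial \bar{z}$, which is what the gradient updates of the complex-valued network follow.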

Heat map of the network loss as a function of training iteration and internal RIM reconstruction time-step.

The above image shows the network loss per time-step of the complex-valued RIM across training iterations. As can be seen, the first couple of time-steps retain a high loss regardless of how long the model is trained, indicating that processing updates to the log-likelihood gradient and the RIM’s internal state is a necessary precursor for proper reconstruction in the ensuing time-steps. The heat map also illustrates that less and less is gained per time-step as the reconstruction approaches the final time-steps.

Until now, the main method of accelerating MRI reconstruction has been an algorithm known as Compressed Sensing (CS). CS exploits the fact that MR-images are compressible, but a suitable compression (a sparsifying transform) must be chosen beforehand, whereas the RIM finds its own compression implicitly. This is a recurring advantage of deep learning as it is applied to new areas: the job of carefully hand-engineering feature extractions can be assigned to the neural network itself, which frequently leads to more useful features. The improvement of the RIM over CS is visually perceptible in the reconstructions shown below.

The image shows reconstruction comparisons between Compressed Sensing and the Recurrent Inference Machine, for acceleration factors 4 and 6.

The benefits are also quantified in the whisker plot below, using the structural similarity index (SSIM) and MSE as metrics. The former is meant to pick up on perceptual differences between two images, yielding a value of 1 when the images are identical.
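As a reference for how such metrics can be computed, here is a minimal sketch using scikit-image's structural_similarity on magnitude images; the function name and the data_range choice are our own assumptions, not the exact evaluation code used here.

    import numpy as np
    from skimage.metrics import structural_similarity

    def evaluate(reconstruction, target):
        """MSE and SSIM between a (magnitude) reconstruction and the
        fully-sampled target; SSIM is 1.0 for identical images."""
        mse = np.mean((reconstruction - target) ** 2)
        ssim = structural_similarity(
            reconstruction, target,
            data_range=target.max() - target.min())
        return mse, ssim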

The average MSE and SSIM across acceleration factors 2, 3, 4, 5 and 6, for CS, and real- and complex-valued RIMs.

The higher the acceleration factor, the greater the improvement. The RIM not only outperforms CS on average, it is also more robust to variations in the sub-sampling pattern. The next image shows whisker plots of the standard deviation across different sub-sampling patterns for the same acceleration factors.

The RIM is robust against changing the sub-sampling pattern used, whereas CS requires greater attention to picking an optimal pattern.


Both images show 4x accelerated Quantitative Susceptibility Map (QSM) reconstructions.

[1] P. Putzky and M. Welling, “Recurrent Inference Machines for Solving Inverse Problems,” arXiv:1706.04008, 2017.
