Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 2)
In this part, we introduce the architecture of the Light-weight Deformable Registration network and the Adversarial Learning algorithm with Distilling Knowledge.
The Architecture of Light-weight Deformable Registration Network
In practice, recent deformation networks follow an encoder-decoder architecture, using 3D convolutions to progressively down-sample the image and deconvolutions (transposed convolutions) to recover the spatial resolution [1, 3]. However, this setup consumes a large number of parameters, so the resulting models are computationally expensive and time-consuming. To overcome this problem, we design a new light-weight student network, illustrated in Figure 1.
In particular, the proposed light-weight network has four convolutional layers and three deconvolutional layers. Each convolutional layer applies a bank of strided 3D filters followed by a ReLU activation function, and the number of output channels doubles at each subsequent convolutional layer. Skip connections between the convolutional layers and the deconvolutional layers are added to help refine the dense prediction. The subnetwork outputs a dense flow prediction field, i.e., a 3-channel volume feature map (one channel per displacement axis) with the same spatial size as the input.
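To make the design concrete, below is a minimal PyTorch sketch of such a student network. It is an illustrative reading of the description above, not the official implementation: the input and channel widths (`in_ch=2`, `base=16`), the kernel sizes, and the final flow-head deconvolution are assumptions.

```python
import torch
import torch.nn as nn

def conv3d(cin, cout):
    # strided 3D convolution halves each spatial dimension; ReLU follows
    return nn.Sequential(nn.Conv3d(cin, cout, 3, stride=2, padding=1),
                         nn.ReLU(inplace=True))

def deconv3d(cin, cout):
    # transposed convolution doubles each spatial dimension
    return nn.Sequential(nn.ConvTranspose3d(cin, cout, 4, stride=2, padding=1),
                         nn.ReLU(inplace=True))

class LightweightStudent(nn.Module):
    def __init__(self, in_ch=2, base=16):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8  # doubling widths
        self.e1, self.e2 = conv3d(in_ch, c1), conv3d(c1, c2)
        self.e3, self.e4 = conv3d(c2, c3), conv3d(c3, c4)
        self.d3 = deconv3d(c4, c3)            # skip from e3 joins its output
        self.d2 = deconv3d(c3 + c3, c2)       # skip from e2
        self.d1 = deconv3d(c2 + c2, c1)       # skip from e1
        # flow head: 3-channel displacement field at the input resolution
        self.flow = nn.ConvTranspose3d(c1 + c1, 3, 4, stride=2, padding=1)

    def forward(self, fixed, moving):
        x = torch.cat([fixed, moving], dim=1)  # register a pair of volumes
        f1 = self.e1(x); f2 = self.e2(f1); f3 = self.e3(f2); f4 = self.e4(f3)
        u3 = torch.cat([self.d3(f4), f3], dim=1)
        u2 = torch.cat([self.d2(u3), f2], dim=1)
        u1 = torch.cat([self.d1(u2), f1], dim=1)
        return self.flow(u1)                   # dense flow prediction field
```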
In comparison with the current state-of-the-art dense deformable registration network [3], the number of parameters of our proposed light-weight student network is greatly reduced. In practice, such a significant reduction may lead to a drop in accuracy. Therefore, we propose a new Adversarial Learning with Distilling Knowledge algorithm that effectively transfers the teacher deformations to our student network, keeping it light-weight while achieving competitive performance.
Adversarial Learning Algorithm with Distilling Knowledge
Our adversarial learning algorithm aims to improve the student network accuracy through the distilled teacher deformations extracted from the teacher network. The learning method comprises a deformation-based adversarial loss and its accompanying learning strategy (Algorithm 1).
Adversarial Loss. The loss function for the light-weight student network is a combination of the discrimination loss $\mathcal{L}_{dis}$ and the reconstruction loss $\mathcal{L}_{res}$, with the forward and backward passes through the loss controlled by Algorithm 1. In particular, the last deformation loss $\mathcal{L}_{d_n}$, which outputs the final warped image, can be written as:

$$\mathcal{L}_{d_n} = \gamma \mathcal{L}_{dis} + (1 - \gamma) \mathcal{L}_{res}$$

where $\gamma$ controls the contribution between $\mathcal{L}_{dis}$ and $\mathcal{L}_{res}$. Note that $\mathcal{L}_{dis}$ is only applied to the final warped image.
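In code, this combination is a single weighted sum. A minimal Python sketch follows, where `gamma` is the balancing hyperparameter (its value is not specified here, so the default is an assumption):

```python
def deformation_loss(l_dis, l_res, gamma=0.1):
    # weighted combination of discrimination and reconstruction losses;
    # gamma trades off the two terms (default value is illustrative only)
    return gamma * l_dis + (1.0 - gamma) * l_res
```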
Discrimination Loss. In the student network, the discrimination loss is computed as:

$$\mathcal{L}_{dis} = \mathbb{E}_{\phi_s}\left[D_\theta(\phi_s)\right] - \mathbb{E}_{\phi_t}\left[D_\theta(\phi_t)\right] + \lambda \, \mathbb{E}_{\hat{\phi}}\left[\left(\left\lVert \nabla_{\hat{\phi}} D_\theta(\hat{\phi}) \right\rVert_2 - 1\right)^2\right]$$

where $\lambda$ controls the gradient penalty regularization. The joint deformation $\hat{\phi}$ is computed from the teacher deformation $\phi_t$ and the predicted student deformation $\phi_s$ as follows:

$$\hat{\phi} = \beta \phi_t + (1 - \beta) \phi_s$$

where $\beta$ controls the effect of the teacher deformation.
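A minimal PyTorch sketch of this critic loss is given below. It assumes the teacher deformations act as the reference ("real") samples for the critic and that `D` returns one scalar per sample; `beta` and `lam` stand in for the weighting symbols above, and their default values are illustrative.

```python
import torch

def discrimination_loss(D, phi_t, phi_s, beta=0.5, lam=10.0):
    # joint deformation: fixed convex mix of teacher and student flows
    phi_hat = (beta * phi_t + (1.0 - beta) * phi_s).detach().requires_grad_(True)
    d_hat = D(phi_hat)
    # gradient penalty: push the critic's gradient norm toward 1
    grads = torch.autograd.grad(d_hat.sum(), phi_hat, create_graph=True)[0]
    gp = ((grads.flatten(1).norm(2, dim=1) - 1.0) ** 2).mean()
    # WGAN-style critic objective over deformations
    return D(phi_s.detach()).mean() - D(phi_t).mean() + lam * gp
```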
In the discrimination loss, $D_\theta$ is the discriminator, formed by a neural network with learnable parameters $\theta$; its details are shown in Figure 3. In particular, $D_\theta$ consists of six convolutional layers. The first layer takes the deformation field as input; its spatial size equals the scaled size of the input image. From the second layer to the last convolutional layer, each layer applies a bank of strided filters followed by a ReLU activation function, except for the last layer, which is followed by a sigmoid activation function. Starting from the second layer, the number of output channels doubles at each subsequent layer.
Basically, this design injects the condition information with a matched tensor dimension and then lets the network learn useful features from the conditional input. The output of the last layer is averaged to form the mean feature of the discriminator. Note that in the discrimination loss, gradient penalty regularization is applied in place of critic weight clipping, which may lead to undesired behavior when training adversarial networks.
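The following PyTorch sketch mirrors that description under stated assumptions: the per-layer channel widths, the kernel size, and the width of the first layer are illustrative, since the exact values are not reproduced here.

```python
import torch.nn as nn

class DeformationDiscriminator(nn.Module):
    def __init__(self, in_ch=3, base=16):
        super().__init__()
        layers, prev = [], in_ch
        # six strided 3D conv layers; channels double from the second layer on
        widths = [base] + [base * 2 ** i for i in range(5)]
        for i, c in enumerate(widths):
            layers.append(nn.Conv3d(prev, c, 4, stride=2, padding=1))
            # ReLU after every layer except the last, which uses a sigmoid
            layers.append(nn.Sigmoid() if i == len(widths) - 1 else nn.ReLU(inplace=True))
            prev = c
        self.net = nn.Sequential(*layers)

    def forward(self, phi):
        # mean feature: average the last feature map into one score per sample
        return self.net(phi).mean(dim=(1, 2, 3, 4))
```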
Reconstruction Loss. The reconstruction loss is an important part of a deformation estimator. Following the VTN [3] baseline, the reconstruction loss is written as:

$$\mathcal{L}_{res}(I_f, I_m \circ \phi) = 1 - \operatorname{CorrCoef}\left[I_f, I_m \circ \phi\right]$$

where

$$\operatorname{CorrCoef}[I_1, I_2] = \frac{\operatorname{Cov}[I_1, I_2]}{\sqrt{\operatorname{Cov}[I_1, I_1] \operatorname{Cov}[I_2, I_2]}}$$

$$\operatorname{Cov}[I_1, I_2] = \frac{1}{|\omega|} \sum_{x \in \omega} I_1(x) I_2(x) - \frac{1}{|\omega|^2} \sum_{x \in \omega} I_1(x) \sum_{y \in \omega} I_2(y)$$

where $\operatorname{CorrCoef}[I_1, I_2]$ is the correlation between two images $I_1$ and $I_2$, $\operatorname{Cov}[I_1, I_2]$ is the covariance between them, and $\omega$ denotes the cuboid (or grid) on which the input images are defined. Here $I_f$ is the fixed image and $I_m \circ \phi$ is the moving image warped by the predicted deformation $\phi$.
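In code, this is a Pearson correlation computed over the image grid. A minimal PyTorch sketch follows (the `eps` guard against division by zero is an added safeguard, not part of the formula):

```python
import torch

def corr_coef(i1, i2, eps=1e-8):
    # Pearson correlation over the flattened grid of each sample
    v1 = i1.flatten(1) - i1.flatten(1).mean(dim=1, keepdim=True)
    v2 = i2.flatten(1) - i2.flatten(1).mean(dim=1, keepdim=True)
    cov = (v1 * v2).mean(dim=1)  # Cov[I1, I2] over the grid
    return cov / (v1.std(dim=1, unbiased=False) * v2.std(dim=1, unbiased=False) + eps)

def reconstruction_loss(fixed, warped):
    # 1 - CorrCoef, averaged over the batch
    return (1.0 - corr_coef(fixed, warped)).mean()
```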
Learning Strategy. The forward and backward passes of the aforementioned losses are controlled by the adversarial learning strategy described in Algorithm 1.
In our deformable registration setup, the roles of real data and attacking data are reversed compared with the traditional adversarial learning strategy. In standard adversarial learning, the model uses unreal (generated) images as attacking data, while image labels are ground truths. In our deformable registration task, however, the model leverages the unreal (generated) deformations from the teacher as attacking data, while the image itself is the ground truth that the model must reconstruct from the input information. As a consequence, the roles of images and labels are reversed in our setup. Since we want the network to learn more from the real data, the generator needs to be updated more frequently. Although the teacher's knowledge is used as attacking data, the information it provides is still meaningful, because it is distilled from a high-performing teacher model. With these characteristics of both the generator and the discriminator, the light-weight student network is expected to learn more effectively and efficiently.
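Since Algorithm 1 itself is not reproduced in this post, the sketch below shows one plausible instantiation of such an alternating strategy, built on the loss sketches above. The helper `warp` (a spatial-transformer-style resampler) and the update ratio `n_gen` are assumptions; `n_gen > 1` reflects the point that the generator is considered more frequently than the critic.

```python
def train_step(student, D, opt_g, opt_d, fixed, moving, phi_t, warp, n_gen=2):
    # --- critic step: teacher flow as reference, student flow as attack ---
    opt_d.zero_grad()
    phi_s = student(fixed, moving)
    discrimination_loss(D, phi_t, phi_s).backward()
    opt_d.step()
    # --- student (generator) steps, taken more often than critic steps ---
    for _ in range(n_gen):
        phi_s = student(fixed, moving)
        l_dis = -D(phi_s).mean()  # adversarial term: fool the critic
        l_res = reconstruction_loss(fixed, warp(moving, phi_s))
        loss = deformation_loss(l_dis, l_res)
        opt_g.zero_grad()
        loss.backward()
        opt_g.step()
    return loss.item()
```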
References
[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., "Recursive cascaded networks for unsupervised medical image registration," in ICCV, 2019.
[2] G. Hinton, O. Vinyals, and J. Dean, "Distilling the knowledge in a neural network," arXiv, 2015.
[3] S. Zhao, T. Lau, J. Luo, E. I. Chang, and Y. Xu, "Unsupervised 3D end-to-end medical image registration with volume tweening network," IEEE J-BHI, 2019.
Open Source
🐱 GitHub: https://github.com/aioz-ai/LDR_ALDK