6 posts tagged with "medical"

View All Tags

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 3)

In this part, we show the effectiveness and ablation studies of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge.

Dataset

As mentioned in [1], we train the method on two types of scans: liver CT scans and brain MRI scans.

For Liver CT scans, we use 5 datasets:

  1. LiTS contains 131 liver segmentation scans.
  2. MSD has 70 liver tumor CT scans, 443 hepatic vessel scans, and 420 pancreatic tumor scans.
  3. BFH is a smaller dataset with 92 scans.
  4. SLIVER is a challenging dataset with 20 liver segmentation scans, annotated by 3 expert doctors.
  5. LSPIG (Liver Segmentation of Pigs) contains 17 pairs of CT scans from pigs, provided by the First Affiliated Hospital of Harbin Medical University.

For Brain MRI scans, we use 4 datasets:

  1. ADNI contains 66 scans.
  2. ABIDE contains 1287 scans.
  3. ADHD contains 949 scans.
  4. LPBA has 40 scans, each featuring a segmentation ground truth of 56 anatomical structures.

Baselines

We compare the LDR + ALDK method with the following recent deformable registration methods:

  • ANTs SyN and Elastix B-spline are methods that find an optimal transformation by iteratively updating the parameters of the defined alignment.
  • VoxelMorph predicts a dense deformation in an unsupervised manner by using deconvolutional layers.
  • VTN is an end-to-end learning framework that uses convolutional neural networks to register 3D medical images, especially large displaced ones.
  • RCN is a recent recursive deep architecture that utilizes learnable cascade and performs progressive deformation for each warped image.

Results

Table 1 summarizes the overall performance, testing speed, and number of parameters compared with recent state-of-the-art methods on the deformable registration task. The results clearly show that the Light-weight Deformable Registration network (LDR) accompanied by the Adversarial Learning with Distilling Knowledge (ALDK) algorithm significantly reduces the inference time and the number of parameters during the inference phase. Moreover, the method achieves accuracy competitive with the most recent high-performing but expensive networks, such as VTN and VoxelMorph. We notice that this improvement is consistent across all experiments on the SLIVER, LiTS, LSPIG, and LPBA datasets.

In particular, we observe that on the SLIVER dataset the Dice score of our best model with 3 cascades (3-cas LDR + ALDK) is 0.3% less than the best result of 3-cas VTN + Affine, while the inference speed is ~21 times faster on a CPU and the number of parameters used during inference is ~8 times smaller. Including the benchmarking results on three other datasets, i.e., LiTS, LSPIG, and LPBA, the light-weight model only trades off an average of 0.5% in Dice score and 1.25% in Jacc score for a significant gain in speed and a massive reduction in the number of parameters. We also notice that the method is the only work that achieves an inference time of approximately 1s on a CPU. This makes the method well suited for deployment, as it does not require expensive GPU hardware for inference.

Fig-1

Table 1: COMPARISON BETWEEN THE LDR + ALDK MODEL AND RECENT APPROACHES.

Ablation Study

Effectiveness of ALDK. Table 2 summarizes the effectiveness of Adversarial Learning with Distilling Knowledge (ALDK) when integrated into the light-weight student network. Note that LDR without ALDK is trained using only the reconstruction loss in an unsupervised learning setup. From this table, we clearly see that the ALDK algorithm improves the Dice score of LDR tested on the SLIVER dataset by 3.4%, 4.0%, and 3.1% for the 1-cas, 2-cas, and 3-cas setups, respectively. Additionally, using ALDK also increases the Jacc score by 5.2%, 4.9%, and 3.9% for 1-cas LDR, 2-cas LDR, and 3-cas LDR. These results verify the stability of the adversarial learning algorithm across different evaluation metrics and numbers of cascades. Furthermore, Table 2 also clearly shows the effectiveness and generalization of ALDK when applied to the student network. Since the deformations extracted from the teacher are used only during training, the adversarial learning algorithm fully maintains the speed and the number of parameters of the light-weight student network during inference. All results indicate that the student network trained with the adversarial learning algorithm successfully achieves the performance goal while maintaining the efficient computational cost of the light-weight setup.

Fig-2

Table 2: THE EFFECTIVENESS OF ALDK WHEN INTEGRATED INTO THE LDR NETWORK.

Accuracy vs. Complexity. Figure 1 compares LDR + ALDK and the baseline VTN on the SLIVER dataset under multiple recursive-cascade setups on both CPU and GPU. On the CPU (Figure 1-a), in the 1-cascade setup, the Dice score of our method is 0.2% less than VTN while the speed is ~15 times faster. The more cascades are used, the larger the speed gap between LDR + ALDK and the baseline VTN becomes, e.g., the CPU speed gap increases to ~21 times in the 3-cascade setup. We observe the same effect on the GPU (Figure 1-b), where the method achieves slightly lower accuracy than VTN while clearly reducing the inference time. These results indicate that LDR + ALDK works well with the teacher network to improve accuracy while significantly reducing the inference time on both CPU and GPU in comparison with the baseline VTN network.

Fig-3

Figure 1: Plots of Dice score and inference speed with respect to the number of cascades for the baseline Affine + VTN and LDR + ALDK: (a) CPU speed and (b) GPU speed. Results are reported on the SLIVER dataset; bars represent the inference speed; lines represent the Dice score. All methods use an Intel Xeon E5-2690 v4 CPU and an Nvidia GeForce GTX 1080 Ti GPU for inference.

Visualization

Figure 2 illustrates a visual comparison among 1-cas LDR, 1-cas LDR + ALDK, and the baseline 1-cas RCN. Five different moving images in a volume are selected and registered to a chosen fixed image. Note that although the sections of the warped segmentations may overlap less with those of the fixed image, the segmentation intersection over union is computed over the volume, not the sections. In the segmented images in Figure 2, besides the matched areas colored white, we also mark the mismatched areas in red for readability.

From Figure 2, we can see that the segmentation results of the 1-cas LDR network without ALDK (Figure 2-a) contain many mismatched areas (denoted in red). However, when we apply ALDK to the student network, the registration results are clearly improved (Figure 2-b). Overall, the LDR + ALDK visualization results in Figure 2-b are competitive with the baseline RCN network (Figure 2-c). This visualization confirms that our framework for deformable registration achieves results comparable to the recent RCN network.

Fig-3

Figure 2: The visual comparison between LDR (a), LDR + ALDK (b), and the baseline RCN (c). The left images are sections of the warped images; the right images are sections of the warped segmentations (white represents the matched areas between the warped image and the fixed image, red denotes the mismatched areas). The segmentation visualization indicates that LDR + ALDK (b) significantly reduces the mismatched areas of the student network LDR (a). Best viewed in color.

Reference

[1] Tran, Minh Q., et al. "Light-weight deformable registration using adversarial learning with distilling knowledge." IEEE Transactions on Medical Imaging, 2022.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge (Part 2)

In this part, we introduce the architecture of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge.

The Architecture of Light-weight Deformable Registration Network

In practice, recent deformation networks follow an encoder-decoder architecture and use 3D convolution to progressively down-sample the image and deconvolution (transposed convolution) to recover spatial resolution [1, 3]. However, this setup consumes a large number of parameters. Therefore, the resulting models are computationally expensive and time-consuming. To overcome this problem, we design a new light-weight student network, as illustrated in Figure 1.

In particular, the proposed light-weight network has four convolutional layers and three deconvolutional layers. Each convolutional layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function. The number of output channels of the convolutional layers starts at 16 in the first layer, doubles at each subsequent layer, and ends up at 128. Skip connections between the convolutional layers and the deconvolutional layers are added to help refine the dense prediction. The subnetwork outputs a dense flow prediction field, i.e., a 3-channel volume feature map with the same size as the input.
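To make the layer layout concrete, below is a minimal PyTorch sketch of such an encoder-decoder. The layer names, the input channel count, the exact skip-connection wiring, and the final upsampling step that restores full resolution are our assumptions and may differ from the released implementation.

```python
import torch
import torch.nn as nn

class LDRStudent(nn.Module):
    """Sketch of the light-weight student: 4 conv layers, 3 deconv layers, skip connections."""
    def __init__(self, in_channels=2):  # moving + fixed volumes stacked (assumption)
        super().__init__()
        # Encoder: 4x4x4 kernels, stride 2, output channels 16 -> 32 -> 64 -> 128
        self.enc1 = nn.Sequential(nn.Conv3d(in_channels, 16, 4, stride=2, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv3d(16, 32, 4, stride=2, padding=1), nn.ReLU())
        self.enc3 = nn.Sequential(nn.Conv3d(32, 64, 4, stride=2, padding=1), nn.ReLU())
        self.enc4 = nn.Sequential(nn.Conv3d(64, 128, 4, stride=2, padding=1), nn.ReLU())
        # Decoder: 3 transposed convolutions; skip features are concatenated before each one
        self.dec3 = nn.Sequential(nn.ConvTranspose3d(128, 64, 4, stride=2, padding=1), nn.ReLU())
        self.dec2 = nn.Sequential(nn.ConvTranspose3d(64 + 64, 32, 4, stride=2, padding=1), nn.ReLU())
        self.dec1 = nn.Sequential(nn.ConvTranspose3d(32 + 32, 16, 4, stride=2, padding=1), nn.ReLU())
        # Final 3-channel dense flow; how full resolution is restored is an assumption here
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
        self.flow = nn.Conv3d(16 + 16, 3, kernel_size=3, padding=1)

    def forward(self, moving, fixed):
        x = torch.cat([moving, fixed], dim=1)
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        d3 = self.dec3(e4)
        d2 = self.dec2(torch.cat([d3, e3], dim=1))
        d1 = self.dec1(torch.cat([d2, e2], dim=1))
        return self.flow(self.up(torch.cat([d1, e1], dim=1)))
```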

In comparison with the current state-of-the-art dense deformable registration network [3], the number of parameters of our proposed light-weight student network is reduced approximately 10 times. In practice, this significant reduction may lead to an accuracy drop. Therefore, we propose a new Adversarial Learning with Distilling Knowledge algorithm to effectively leverage the teacher deformations $\phi_t$ for our introduced student network, making it light-weight while achieving competitive performance.

Fig-1

Figure 1: The structure of Light-weight Deformable Registration student network. The number of channels is annotated above the layer. Curved arrows represent skip paths (layers connected by an arrow are concatenated before transposed convolution). Smaller canvas means lower spatial resolution (Source).

Adversarial Learning Algorithm with Distilling Knowledge

Our adversarial learning algorithm aims to improve the student network accuracy through the distilled deformations extracted from the teacher network. The learning method comprises a deformation-based adversarial loss $\mathcal{L}_{adv}$ and its accompanying learning strategy (Algorithm 1).

Fig-2

Figure 2: Adversarial learning strategy (Source).

Adversarial Loss. The loss function for the light-weight student network is a combination of the discrimination loss $l_{dis}$ and the reconstruction loss $l_{rec}$. The forward and backward passes through this loss function are controlled by Algorithm 1. In particular, the adversarial loss $\mathcal{L}_{adv}$ applied to the final warped image can be written as:

$$\mathcal{L}_{adv} = \gamma l_{rec} + (1 - \gamma) l_{dis}$$

where $\gamma$ controls the contribution between $l_{rec}$ and $l_{dis}$. Note that $\mathcal{L}_{adv}$ is only applied to the final warped image.
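As a quick illustration, the weighting above amounts to a one-line combination of the two terms; the value of $\gamma$ below is only a placeholder, not the paper's setting.

```python
def adversarial_loss(l_rec, l_dis, gamma=0.9):
    """L_adv = gamma * l_rec + (1 - gamma) * l_dis, applied only to the final warped image."""
    return gamma * l_rec + (1.0 - gamma) * l_dis
```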

Discrimination Loss. In the student network, the discrimination loss is computed as in the equation below.

$$l_{dis} = \left\lVert D_\mathbf{\theta}(\phi_{s}) - D_\mathbf{\theta}(\phi_{t}) \right\rVert_2^{2} + \lambda\bigg(\left\lVert \nabla_{\hat\phi_{s}}D_\mathbf{\theta}(\hat\phi_{s}) \right\rVert_2 - 1\bigg)^{2}$$

where $\lambda$ controls the gradient penalty regularization. The joint deformation $\hat\phi_{s}$ is computed from the teacher deformation $\phi_{t}$ and the predicted student deformation $\phi_{s}$ as follows:

$$\hat\phi_{s} = \beta \phi_{t} + (1 - \beta) \phi_{s}$$

where $\beta$ controls the effect of the teacher deformation.

In the discrimination loss, $D_\mathbf{\theta}$ is the discriminator, formed by a neural network with learnable parameters $\theta$. The details of $D_\mathbf{\theta}$ are shown in Figure 3. In particular, $D_\mathbf{\theta}$ consists of six 3D convolutional layers; the first layer is $128 \times 128 \times 128 \times 3$ and takes the $c \times c \times c \times 1$ deformation as input, where $c$ equals the scaled size of the input image. The second layer is $64 \times 64 \times 64 \times 16$. From the second layer to the last convolutional layer, each layer has a bank of $4 \times 4 \times 4$ filters with strides of $2 \times 2 \times 2$, followed by a ReLU activation function, except for the last layer, which is followed by a sigmoid activation function. The number of output channels starts at 16 in the second layer, doubles at each subsequent layer, and ends up at 256.

Basically, this injects the conditioning information with a matched tensor dimension and then lets the network learn useful features from the conditional input. The output of the last layer is the mean feature of the discriminator, denoted as $M$. Note that in the discrimination loss, a gradient penalty regularization is applied to avoid critic weight clipping, which may lead to undesired behavior when training adversarial networks.

Fig-3

Figure 3: The structure of the discriminator $D_\mathbf{\theta}$ used in the discrimination loss ($l_{dis}$) of our Adversarial Learning with Distilling Knowledge algorithm (Source).
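The discrimination loss and the gradient penalty on the joint deformation can be sketched in PyTorch as follows. This is a minimal illustration under WGAN-GP-style conventions; the discriminator `D`, the reduction over the batch, and the values of `beta` and `lam` are assumptions, not the paper's exact settings.

```python
import torch

def discrimination_loss(D, phi_s, phi_t, beta=0.5, lam=10.0):
    """Sketch of l_dis: feature matching between the student and teacher
    deformations plus a gradient penalty on the joint deformation."""
    # Joint deformation: phi_hat = beta * phi_t + (1 - beta) * phi_s
    phi_hat = (beta * phi_t + (1.0 - beta) * phi_s).detach().requires_grad_(True)

    # || D(phi_s) - D(phi_t) ||_2^2, averaged over the batch
    match_term = (D(phi_s) - D(phi_t)).flatten(1).pow(2).sum(dim=1).mean()

    # Gradient penalty: (|| grad_{phi_hat} D(phi_hat) ||_2 - 1)^2
    d_hat = D(phi_hat)
    grads = torch.autograd.grad(d_hat.sum(), phi_hat, create_graph=True)[0]
    grad_norm = grads.flatten(1).norm(2, dim=1)
    penalty = ((grad_norm - 1.0) ** 2).mean()

    return match_term + lam * penalty
```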

Reconstruction Loss. The reconstruction loss $l_{rec}$ is an important part of a deformation estimator. Following the VTN [3] baseline, the reconstruction loss is written as:

$$l_{rec}(\textbf{\textit{I}}_m^h,\textbf{\textit{I}}_f) = 1 - CorrCoef[\textbf{\textit{I}}_m^h,\textbf{\textit{I}}_f]$$

where

$$CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_2]}{\sqrt{Cov[\textbf{\textit{I}}_1,\textbf{\textit{I}}_1]\,Cov[\textbf{\textit{I}}_2,\textbf{\textit{I}}_2]}}$$

$$Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2] = \frac{1}{|\omega|}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\textbf{\textit{I}}_2(x) - \frac{1}{|\omega|^{2}}\sum_{x \in \omega} \textbf{\textit{I}}_1(x)\sum_{y \in \omega}\textbf{\textit{I}}_2(y)$$

where $CorrCoef[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the correlation between the two images $\textbf{\textit{I}}_1$ and $\textbf{\textit{I}}_2$, and $Cov[\textbf{\textit{I}}_1, \textbf{\textit{I}}_2]$ is the covariance between them. $\omega$ denotes the cuboid (or grid) on which the input images are defined.
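A direct PyTorch translation of the two equations above might look like the following sketch; the small `eps` term is our addition for numerical stability.

```python
import torch

def correlation_coefficient(i1, i2, eps=1e-8):
    """CorrCoef over the voxel grid omega of each volume in the batch."""
    i1 = i1.flatten(1)  # (B, |omega|)
    i2 = i2.flatten(1)
    cov12 = (i1 * i2).mean(dim=1) - i1.mean(dim=1) * i2.mean(dim=1)
    cov11 = (i1 * i1).mean(dim=1) - i1.mean(dim=1) ** 2
    cov22 = (i2 * i2).mean(dim=1) - i2.mean(dim=1) ** 2
    return cov12 / (cov11 * cov22 + eps).sqrt()  # eps added for numerical stability

def reconstruction_loss(warped, fixed):
    """l_rec = 1 - CorrCoef[warped, fixed], averaged over the batch."""
    return (1.0 - correlation_coefficient(warped, fixed)).mean()
```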

Learning Strategy. The forward and backward passes of the aforementioned $\mathcal{L}_{adv}$ are controlled by the adversarial learning strategy described in Algorithm 1.

In our deformable registration setup, the roles of real data and attacking data are reversed compared with the traditional adversarial learning strategy. In traditional adversarial learning, the model uses unreal (generated) images as attacking data, while the image labels are ground truths. However, in our deformable registration task, the model leverages the unreal (generated) deformations from the teacher as attacking data, while the image is the ground truth the model uses to reconstruct the input information. As a consequence, the roles of images and labels are reversed in our setup. Since we want the information to be learned more from real data, the generator needs to be considered more frequently. Although the knowledge in the discriminator is used as attacking data, the information it provides is meaningful because the distilled information is inherited from the high-performing teacher model. With these characteristics of both the generator and the discriminator, the light-weight student network is expected to learn more effectively and efficiently.

Reference

[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, I. Eric, C. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge

Introduction: Medical image registration

Medical image registration is the process of systematically placing separate medical images in a common frame of reference so that the information they contain can be effectively integrated or compared. Applications of image registration include combining images of the same subject from different modalities, aligning temporal sequences of images to compensate for the motion of the subject between scans, aligning images from multiple subjects in cohort studies, or navigating with image guidance during interventions. Since many organs deform substantially while being scanned, and scanner-induced geometrical distortions can differ between images, the rigid assumption is often violated. Therefore, performing deformable registration is an essential step in many medical procedures.

Previous Studies, Remaining Challenges, and Motivation

Recently, learning-based methods have become popular for tackling the problem of deformable registration. These methods can be split into two groups: (i) supervised methods that rely on dense ground-truth flows obtained either from traditional algorithms or by simulating intra-subject deformations; although these works achieve state-of-the-art performance, they require a large amount of manually labeled training data, which is expensive to obtain; and (ii) unsupervised learning methods that use a similarity measurement between the moving and the fixed image to exploit a large amount of unlabelled data. These unsupervised methods achieve competitive results in comparison with supervised methods. However, their deformations are reconstructed without direct ground-truth guidance, which limits how much learnable information can be leveraged. Furthermore, recent unsupervised methods all share the issue of high complexity, as the number of network parameters increases significantly when multiple progressive cascades are taken into account. As a result, these works cannot achieve real-time performance during inference and require intensive computational resources for deployment.

In practice, there are many scenarios in which medical image registration needs to be fast: matching preoperative and intra-operative images during surgery, interactive change detection of CT or MRI data for a radiologist, deformation compensation or 3D alignment of large histological slices for a pathologist, or processing large amounts of images from high-throughput imaging methods. Besides, in many image-guided robotic interventions, performing real-time deformable registration is an essential step to register the images and deal with organs that deform substantially. Economically, the development of a CPU-friendly solution for deformable registration will significantly reduce the cost of instruments in the operating theatre, as it does not require GPUs or cloud-based computing servers, which are costly and consume much more power than a CPU. This will benefit patients in low- and middle-income countries, which face limitations in local equipment, personnel expertise, and budget-constrained infrastructure. Therefore, designing an efficient model that is both fast and accurate for deformable registration is a crucial task worth studying in order to improve a variety of surgical interventions.

Contribution

Deformable registration is a crucial step in many medical procedures such as image-guided surgery and radiation therapy. Most recent learning-based methods focus on improving accuracy by optimizing the non-linear spatial correspondence between the input images. As a result, these methods are computationally expensive and require modern graphics cards for real-time deployment. We therefore introduce a new Light-weight Deformable Registration network that significantly reduces the computational cost while achieving competitive accuracy (Fig. 1). In particular, we propose a new adversarial learning with distilling knowledge algorithm that successfully leverages meaningful information from the effective but expensive teacher network to the student network. We design the student network such that it is light-weight and well suited for deployment on a typical CPU. Extensive experimental results on different public datasets show that our proposed method achieves state-of-the-art accuracy while being significantly faster than recent methods. We further show that the use of our adversarial learning algorithm is essential for a time-efficient deformable registration method.

Fig-1

Figure 1: Comparison between typical deep learning-based methods for deformable registration (a) and our approach using adversarial learning with distilling knowledge for deformable registration (b). In our work, the expensive Teacher Network is used only in training; the Student Network is light-weight and inherits helpful knowledge from the Teacher Network via our Adversarial Learning algorithm. Therefore, the Student Network has high inference speed, while achieving competitive accuracy (Source).

Methodology

Method overview

We describe our method for Light-weight Deformable Registration using Adversarial Learning with Distilling Knowledge. Our method is composed of three main components: (i) a Knowledge Distillation module which extracts meaningful deformations $\bm{\phi}_t$ from the Teacher Network; (ii) a Light-weight Deformable Registration (LDR) module which outputs a high-speed Student Network; and (iii) an Adversarial Learning with Distilling Knowledge (ALDK) algorithm which effectively transfers the teacher deformations $\bm{\phi}_t$ to the student deformations. An overview of our proposed deformable registration method can be found in Fig. 2.

Fig-2

Figure 2: An overview of our proposed Light-weight Deformable Registration (LDR) method using Adversarial Learning with Distilling Knowledge (ALDK). First, using knowledge distillation, we extract the deformations from the Teacher Network as meaningful ground truths. Second, we design a light-weight Student Network with competitive speed. Finally, we employ the Adversarial Learning with Distilling Knowledge algorithm to effectively transfer the meaningful knowledge of the distilled deformations from the Teacher Network to the Student Network (Source).

Since the full content would be lengthy, in this part we introduce the background theory for Deformable Registration and Knowledge Distillation for Deformation. In the next part, we will introduce the architecture of the Light-weight Deformable Registration Network and the Adversarial Learning Algorithm with Distilling Knowledge. In the final part, we will present the effectiveness of the method in comparison with recent state-of-the-art methods, together with a detailed analysis.

Background: Deformable Registration

We follow RCN [1] to define the deformable registration task recursively using multiple cascades. Let $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ denote the moving image and the fixed image respectively, both defined over the $d$-dimensional space $\bm{\Omega}$. A deformation is a mapping $\bm{\phi} : \bm{\Omega} \rightarrow \bm{\Omega}$. A reasonable deformation should be continuously varying and prevented from folding. The deformable registration task is to construct a flow prediction function $\textbf{F}$ which takes $\textbf{\textit{I}}_m, \textbf{\textit{I}}_f$ as inputs and predicts a dense deformation $\bm{\phi}$ that aligns $\textbf{\textit{I}}_m$ to $\textbf{\textit{I}}_f$ using a warp operator $\circ$ as follows:

$$\textbf{F}^{(n)}(\textbf{\textit{I}}^{(n-1)}_m,\textbf{\textit{I}}_f)=\phi^{(n)} \circ \textbf{F}^{(n-1)}(\phi^{(n-1)} \circ \textbf{\textit{I}}^{(n-2)}_m,\textbf{\textit{I}}_f)$$

where $\textbf{F}^{(n-1)}$ is the same as $\textbf{F}^{(n)}$ but is a different flow prediction function. Assuming there are $n$ cascades in total, the final output is a composition of all predicted deformations, i.e.,

$$\textbf{F}(\textbf{\textit{I}}_m, \textbf{\textit{I}}_f)=\phi^{(n)} \circ \ldots \circ \phi^{(1)},$$

and the final warped image is constructed by

$$\textbf{\textit{I}}_{m}^{(n)}=\textbf{F}(\textbf{\textit{I}}_m,\textbf{\textit{I}}_f) \circ \textbf{\textit{I}}_m$$

In general, the previous equations form the hypothesis function $\mathcal{F}$ under the learnable parameters $\mathbf{W}$:

$$\mathcal{F}(\textbf{\textit{I}}_{m}, \textbf{\textit{I}}_f, \mathbf{W}) = (\mathbf{v}_{\phi}, \textbf{\textit{I}}_m^{(n)})$$

where $\mathbf{v}_{\phi} = [\bm{\phi}^{(1)}, \bm{\phi}^{(2)}, \ldots, \bm{\phi}^{(k)}, \ldots, \bm{\phi}^{(n)}]$ is a vector containing the predicted deformations of all cascades. Each deformation $\bm{\phi}^{(k)}$ can be computed as

$$\bm{\phi}^{(k)} = {\mathcal{F}}^{(k)}\left(\textbf{\textit{I}}_{m}^{(k-1)}, \textbf{\textit{I}}_f, \mathbf{W}_{\phi^{(k)}}\right)$$

To estimate and achieve a good deformation, different networks are introduced to define and optimize the learnable parameters $\mathbf{W}$.
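The recursion above can be summarized as progressively warping the moving image with each cascade's predicted flow. Below is a short Python sketch of this process; the cascade networks and the warp operator are assumed to be provided.

```python
def register_with_cascades(moving, fixed, cascades, warp):
    """Sketch of recursive cascaded registration.

    `cascades` is a list of flow prediction networks F^(1)..F^(n) and `warp`
    applies a dense deformation to a volume (both assumed to be provided).
    Returns the per-cascade deformations v_phi and the final warped image I_m^(n).
    """
    flows = []
    warped = moving
    for F_k in cascades:
        phi_k = F_k(warped, fixed)    # phi^(k) = F^(k)(I_m^(k-1), I_f)
        warped = warp(warped, phi_k)  # I_m^(k) = phi^(k) o I_m^(k-1)
        flows.append(phi_k)
    return flows, warped
```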

Knowledge Distillation for Deformation

Knowledge distillation is the process of transferring knowledge from a cumbersome model (teacher model) to a distilled model (student model). The popular way to achieve this goal is to train the student model on a transfer set using a soft target distribution produced by the teacher model.

Different from typical knowledge distillation methods that use the output softmax of neural networks as the knowledge, in the deformable registration task we leverage the teacher deformation $\bm{\phi}_t$ as the transferred knowledge. As discussed in [2], teacher networks are usually high-performing networks with good accuracy. Therefore, our goal is to leverage the current state-of-the-art Recursive Cascaded Networks (RCN) [1] as the teacher network to extract meaningful deformations for the student network. The RCN network contains an affine transformation and a large number of dense deformable registration sub-networks designed by VTN [3]. Although the teacher network has expensive computational costs, it is only applied during training and is not used during inference.

Reference

[1] S. Zhao, Y. Dong, E. I. Chang, Y. Xu, et al., Recursive cascaded networks for unsupervised medical image registration, in ICCV, 2019.

[2] G. Hinton, O. Vinyals, and J. Dean, Distilling the knowledge in a neural network, ArXiv, 2015.

[3] S. Zhao, T. Lau, J. Luo, I. Eric, C. Chang, and Y. Xu, Unsupervised 3d end-to-end medical image registration with volume tweening network, IEEE J-BHI, 2019.

Open Source

🐱 Github: https://github.com/aioz-ai/LDR_ALDK

Multiple Meta-model Quantifying for Medical Visual Question Answering

Motivation

A medical Visual Question Answering (VQA) system can provide meaningful references for both doctors and patients during the treatment process. Extracting image features is one of the most important steps in a medical VQA framework, as it outputs the essential information used to predict answers. Transfer learning, in which deep learning models pretrained on a large-scale labeled dataset such as ImageNet are finetuned on the target task, is a popular way to initialize the feature extraction process. However, due to the difference in visual concepts between ImageNet images and medical images, the finetuning process is not sufficient. Recently, Model-Agnostic Meta-Learning (MAML) has been introduced to overcome the aforementioned problem by learning meta-weights that quickly adapt to visual concepts. However, MAML is heavily impacted by the meta-annotation phase for all images in the medical dataset. Different from normal images, transfer learning in medical images is more challenging due to:

  • (i) noisy labels may occur when labeling images in an unsupervised manner;
  • (ii) high-level semantic labels cause uncertainty during learning;
  • (iii) difficulty in scaling up the process to all unlabeled images in medical datasets.

Overcoming Data Limitation in Medical Visual Question Answering

What are the difficulties when dealing with Medical VQA task?

Visual Question Answering (VQA) aims to provide a correct answer to a given question such that the answer is consistent with the visual content of a given image.

In the medical domain, VQA could benefit both doctors and patients. For example, doctors could use answers provided by a VQA system as support materials in decision making, while patients could ask VQA questions related to their medical images to better understand their health.

Fig-1

Figure 1: An example of Medical VQA (Source).

However, one major problem with medical VQA is the lack of large-scale labeled training data, which usually requires huge effort to build.

  • The first attempt at building a dataset for medical VQA was made by ImageCLEF-Med. In this dataset, images were automatically captured from PubMed Central articles, and the questions and answers were automatically generated from the corresponding image captions. By that construction, the data has a high noise level, i.e., the dataset includes many images that are not useful for direct patient care and contains questions that do not make any sense.
  • Recently, VQA-RAD, the first manually constructed dataset for the medical VQA task, was released. Unfortunately, it contains only 315 images, which prevents directly applying powerful deep learning models to the VQA problem. One may think about transfer learning, in which deep learning models pretrained on a large-scale labeled dataset such as ImageNet are finetuned on the medical VQA data. However, due to the difference in visual concepts between ImageNet images and medical images, finetuning with very few medical images is not sufficient.

Therefore, it is necessary to develop a new VQA framework that can improve the accuracy while requiring only a small amount of labeled training data.

The motivation for our approach to overcome the data limitation of medical VQA comes from two observations:

  • Firstly, we observe that large-scale unlabeled medical images are available. These images come from the same domain as the medical VQA images. Hence, if we train an unsupervised deep learning model using these unlabeled images, the trained weights may be easier to adapt to the medical VQA problem than weights pretrained on ImageNet images.
  • Another observation is that although the labeled VQA-RAD dataset is primarily designed for VQA, with a little effort we can extract new class labels for it. The new class labels allow us to apply a recent meta-learning technique for learning meta-weights that can be quickly adapted to the VQA problem later.

Methodology

The proposed medical VQA framework is presented in Figure 2. In our framework, the image feature extraction component is initialized with pretrained weights from MAML and CDAE. After that, the VQA framework is finetuned in an end-to-end manner on the medical VQA data. In the following sections, we detail the architectures of MAML, CDAE, and our framework.

Fig-2

Figure 2: The proposed medical VQA framework. The image feature extraction is denoted as 'Mixture of Enhanced Visual Features (MEVF)' and is marked with the red dashed box. The weights of MEVF are initialized by MAML and CDAE (Source).

Model-Agnostic Meta-Learning -- MAML

The MAML model consists of four $3 \times 3$ convolutional layers with stride 2 and ends with a mean pooling layer; each convolutional layer has 64 filters and is followed by a ReLU layer.
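A minimal PyTorch sketch of this feature extractor is given below; the padding and the number of input channels are assumptions.

```python
import torch.nn as nn

def maml_encoder(in_channels=1):
    """Sketch of the MAML feature extractor: four 3x3 conv layers (stride 2, 64 filters,
    ReLU), followed by mean pooling over the spatial dimensions."""
    layers = []
    for i in range(4):
        layers += [
            nn.Conv2d(in_channels if i == 0 else 64, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
    layers.append(nn.AdaptiveAvgPool2d(1))  # mean pooling
    return nn.Sequential(*layers)
```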

We create the dataset for training MAML by manually reviewing around three thousand question-answer pairs from the training set of the VQA-RAD dataset. In our annotation process, images are split into three parts based on their body-part labels (head, chest, abdomen). Images from each body part are further divided into three subcategories based on the interpretation of the question-answer pairs corresponding to the images. These subcategories are: 1. normal images, in which no pathology is found; 2. abnormal present images, in which fluid, air, a mass, or a tumor is present; 3. abnormal organ images, in which the organs are abnormally large or in the wrong position.

Thus, all the images are categorized into 9 classes:

| Body part | Normal | Abnormal present | Abnormal organ |
| --- | --- | --- | --- |
| Head | head normal | head abnormal present | head abnormal organ |
| Chest | chest normal | chest abnormal present | chest abnormal organ |
| Abdomen | abdominal normal | abdominal abnormal present | abdominal abnormal organ |

For every training iteration of MAML (line 3 in Alg. 1), 5 tasks are sampled. For each task, we randomly select 3 classes (from the 9 classes). For each class, we randomly select 6 images, in which 3 images are used for updating the task model and the remaining 3 images are used for updating the meta-model.

Alg-1
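The task sampling described above (5 tasks per iteration, 3 classes per task, 6 images per class split 3/3) can be sketched as follows; the data layout (a dict mapping class names to image lists) is an assumption.

```python
import random

def sample_task(images_by_class, n_way=3, k_support=3, k_query=3):
    """One MAML task: 3 of the 9 classes, 6 images per class
    (3 to update the task model, 3 to update the meta-model)."""
    classes = random.sample(list(images_by_class.keys()), n_way)
    support, query = [], []
    for c in classes:
        imgs = random.sample(images_by_class[c], k_support + k_query)
        support += [(img, c) for img in imgs[:k_support]]
        query += [(img, c) for img in imgs[k_support:]]
    return support, query

# Each training iteration samples 5 such tasks:
# tasks = [sample_task(images_by_class) for _ in range(5)]
```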

Denoising Auto Encoder -- CDAE

The encoder maps an image $x'$, which is the noisy version of the original image $x$, to a latent representation $z$ that retains a useful amount of information. The decoder transforms $z$ to the output $y$. The training algorithm aims to minimize the reconstruction error between $y$ and the original image $x$ as follows:

$$L_{rec} = \left\| x - y \right\|_2^2$$

In our design, the encoder is a stack of convolutional layers, each followed by a max pooling layer. The decoder is a stack of deconvolutional and convolutional layers. The noisy version $x'$ is obtained by adding Gaussian noise to the original image $x$.
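The following PyTorch sketch illustrates this encoder-decoder and the reconstruction objective; the number of layers, channel widths, and noise level are assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn as nn

class CDAE(nn.Module):
    """Sketch of the denoising auto-encoder: conv + max-pooling encoder,
    deconv + conv decoder. Depth and channel widths are assumptions."""
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 32, 2, stride=2), nn.ReLU(),
            nn.Conv2d(32, 1, 3, padding=1),
        )

    def forward(self, x_noisy):
        return self.decoder(self.encoder(x_noisy))

def cdae_loss(model, x, noise_std=0.1):
    """L_rec = ||x - y||_2^2, where y is reconstructed from the Gaussian-corrupted input."""
    x_noisy = x + noise_std * torch.randn_like(x)
    y = model(x_noisy)
    return ((x - y) ** 2).sum(dim=(1, 2, 3)).mean()
```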

To train CDAE, we collect 11,779 unlabeled images available online, consisting of brain MRI images, chest X-ray images, and abdominal CT images. The dataset is split into a training set with 9,423 images and a test set with 2,356 images. We use Gaussian noise to corrupt the input images before feeding them to the encoder.

Our VQA framework

After training MAML and CDAE, we use their trained weights to initialize the MEVF image feature extraction component in the VQA framework. We then finetune the whole VQA model on the training set of the VQA-RAD dataset.

To train the proposed model, we introduce a multi-task loss function to incorporate the effectiveness of the CDAE into VQA. Formally, our loss function is defined as follows:

$$L = \alpha_1 L_{vqa} + \alpha_2 L_{rec}$$

where $L_{vqa}$ is a cross-entropy loss for VQA answer classification and $L_{rec}$ stands for the reconstruction loss of the CDAE. The whole VQA model is finetuned in an end-to-end manner.
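A minimal sketch of this multi-task objective is shown below; the $\alpha$ values are placeholders, not the tuned coefficients.

```python
import torch.nn.functional as F

def multitask_loss(vqa_logits, answer_targets, reconstructed, original,
                   alpha1=1.0, alpha2=0.001):
    """L = alpha1 * L_vqa + alpha2 * L_rec (the alpha values are placeholders)."""
    l_vqa = F.cross_entropy(vqa_logits, answer_targets)  # VQA answer classification
    l_rec = ((reconstructed - original) ** 2).mean()     # CDAE reconstruction term
    return alpha1 * l_vqa + alpha2 * l_rec
```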

Results

Tab-1

Table 1: VQA results on the VQA-RAD test set. All reference methods differ in the image feature extraction component; the other components are similar. The Stacked Attention Network (SAN) is used as the attention mechanism in all methods (Source).

Table 1 presents the VQA accuracy for both open-ended and close-ended questions on the VQA-RAD test set. The results show that for both MAML and CDAE, pretraining and then finetuning significantly improves performance over training from scratch using only VQA-RAD.

In addition, the results show that pretraining and finetuning MAML and CDAE gives better performance than finetuning VGG-16 pretrained on ImageNet. Our proposed image feature extraction module MEVF, which leverages the pretrained weights of both MAML and CDAE and then finetunes them, gives the best performance. This confirms the effectiveness of the proposed MEVF in dealing with the limited labeled training data for medical VQA.

Tab-2

Table 2: Performance comparison on VQA-RAD test set (Source).

Table 2 presents comparative results between methods. Note that for image feature extraction, the baselines use pretrained models (VGG or ResNet) that have been trained on ImageNet and then finetuned on the VQA-RAD dataset. For question feature extraction, all baselines and our framework use the same pretrained model (i.e., GloVe) with finetuning on VQA-RAD. The results show that when BAN or SAN is used as the attention mechanism in our framework, it significantly outperforms the corresponding baseline BAN and SAN frameworks. Our best setting, i.e., the one with BAN as the attention mechanism, achieves state-of-the-art results and significantly outperforms the best baseline framework BAN, with improvements of 16.3% and 8.6% on open-ended and close-ended VQA, respectively.

Conclusion

In this paper, we proposed a novel medical VQA framework that leverages the meta-learning method MAML and the denoising auto-encoder CDAE for image feature extraction in order to overcome the limitation of labeled training data. Specifically, CDAE helps leverage information from large-scale unlabeled images, while MAML helps learn meta-weights that can be quickly adapted to the VQA problem. We establish new state-of-the-art results on the VQA-RAD dataset for both close-ended and open-ended questions.

Open Source

🐱 Github: https://github.com/aioz-ai/MICCAI19-MedVQA

Data Augmentation for Colon Polyp Detection: A Systematic Study

Colorectal cancer (CRC)♋, also known as bowel cancer or colon cancer, is a cancer that develops from a growth in the colon or rectum called a polyp. Detecting polyps is a common approach in screening colonoscopies to prevent CRC at an early stage. Early colon polyp detection from medical images is still an unsolved problem due to the considerable variation of polyps in shape, texture, size, color, and illumination, and the lack of publicly annotated datasets. At AIOZ, we adopt a recently proposed auto-augmentation method for polyp detection. We also conduct a systematic study on the performance of different data augmentation methods for colon polyp detection. The experimental results show that auto-augmentation achieves the best performance compared to other augmentation strategies.

Introduction

Colorectal cancer (CRC) is the third-largest cause of cancer deaths worldwide in men and the second-largest in women, with up to 700,000 patients dying each year [1]. Detection and removal of colon polyps at an early stage will reduce mortality from CRC. There are several methods for colon screening, such as CT colonography or wireless capsule endoscopy, but the gold standard is colonoscopy [2].

A colonoscopy is performed by an experienced doctor who uses a colonoscope to screen and scan for abnormalities such as intestinal signs, symptoms, colon cancer, and polyps. Abnormal polyps can be removed, and small amounts of tissue can be detached for analysis during the colonoscopy. However, the most crucial drawback of colonoscopy is the polyp miss rate, especially for polyps smaller than 10mm. Several factors cause this miss rate: subjective factors such as bowel preparation, the specific choice of endoscope, video processor, and clinician skill, and objective factors such as polyp appearance and camera movement conditions. For these reasons, automatic polyp detection is a potential approach to assist clinicians in improving the sensitivity of the diagnosis.

Previous research shows that automatic polyp detection using deep learning-based methods outperforms hand-crafted methods, as demonstrated by the top two results in the MICCAI 2015 challenge [3]. For deep learning-based approaches, besides the model architecture, data augmentation is also a critical factor in making significant improvements due to the lack of annotated data. Recent work [4] shows that learning an optimal augmentation policy from data, instead of hand-crafting data augmentation strategies, generalizes better. Thus, studying auto-augmentation for the polyp detection problem is necessary. In this research, we adapt Faster R-CNN [5] together with AutoAugment [6] to detect polyps from colonoscopy video frames. In addition, we also evaluate traditional data augmentation [7] to see the effectiveness of different augmentation strategies.

Methodology

1. Polyp Detector

Thanks to the power of deep learning, recent works [5, 12, 11] show that deep learning-based detection methods give impressive detection performance. In this work, we use the Faster R-CNN object detector [5] with a ResNet-101 [13] backbone pre-trained on the COCO dataset. Our experiments show that this architecture gives competitive performance on the polyp detection problem. The experimental setting for the detector is as follows. The network is trained using stochastic gradient descent (SGD) with 0.9 momentum; the learning rate is initially set to 3e-4 and is decreased to 3e-5 from iteration 900k. The number of anchor boxes per location used in our model is 12 (4 scales, i.e., 64×64, 128×128, 256×256, and 512×512, and 3 aspect ratios, i.e., 1:2, 1:1, and 2:1).
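The sketch below reconstructs the anchor configuration and the SGD schedule described above for illustration; the optimizer and scheduler objects are our own approximation, and the placeholder module stands in for the actual Faster R-CNN detector.

```python
import torch

# Anchor configuration: 12 anchor boxes per location (4 scales x 3 aspect ratios)
scales = ["64x64", "128x128", "256x256", "512x512"]
aspect_ratios = ["1:2", "1:1", "2:1"]
anchor_configs = [(s, r) for s in scales for r in aspect_ratios]
assert len(anchor_configs) == 12

# Training schedule: SGD with 0.9 momentum, lr 3e-4 decreased to 3e-5 at iteration 900k
model = torch.nn.Linear(1, 1)  # placeholder module; the real model is Faster R-CNN + ResNet-101
optimizer = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[900_000], gamma=0.1)
```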

2. Data Augmentation

autoaugment_figure_small

Fig. 1. Example of applying learned augmentation policies to a colonoscopy image.

Data augmentation can be split into two types: self-defined data augmentation (a.k.a. traditional augmentation) and auto-augmentation [6]. In this study, we adopt an automated data augmentation approach for object detection, i.e., AutoAugment [6], which finds optimal data augmentation policies during training. In AutoAugment, an augmentation policy consists of several sub-policies, and each sub-policy consists of two operations. Each operation is an image transformation containing two parameters: the probability and the magnitude. There are three types of transformations used in AutoAugment for object detection [4]:

  • Color operations: distort color channels without impacting the locations of the bounding boxes
  • Geometric operations: geometrically distort the image, which correspondingly alters the location and size of the bounding box annotations
  • Bounding box operations: only distort the pixel content contained within the bounding box annotations.

One of the essential conclusions in [4] is that the learned policy found on COCO can be directly applied to other detection datasets and models to improve predictive accuracy. Hence, in this study, we apply the learned policies from [4] to augment the data when training the detector in Sec. 3.1. The learned policies we use for training our detector are summarized in Table 1. In Table 1, each operation is a triple that describes the transformation, the probability, and the magnitude of the transformation. Due to space limitations, we refer the reader to [4] for details on the descriptions of the transformations. Fig. 1 shows augmented examples when applying the learned augmentation policy to a polyp image from the training dataset.

tab-1

Table 1. Sub-policies and operations used in our experiment.

In addition to auto-augmentation, we also investigate the effect of traditional augmentation and of combining traditional and automatic augmentation. For traditional data augmentation, we randomly apply several transformations to the image, such as rotation, mirroring, shearing, translation, and zoom (see the sketch below). We propose different strategies to combine these data augmentation types: (1) the detector is first trained with AutoAugment and then trained with traditional data augmentation; (2) training with traditional augmentation first, then with AutoAugment; (3) training with AutoAugment on the original data plus the data generated by traditional augmentation. All augmentation strategies are evaluated with the same model architecture and training configuration. This allows us to explore which data augmentation method is suitable for the polyp detection problem.
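As an illustration of the traditional augmentations listed above, the following sketch uses the albumentations library to apply mirroring, translation, zoom, and rotation while keeping bounding boxes consistent; the probabilities and limits are illustrative only, and shearing is omitted for brevity.

```python
import albumentations as A

# Traditional augmentations (mirroring, translation, zoom, rotation) with
# bounding boxes kept consistent; values are illustrative only.
traditional_augment = A.Compose(
    [
        A.HorizontalFlip(p=0.5),                      # mirroring
        A.ShiftScaleRotate(shift_limit=0.1,           # translation
                           scale_limit=0.2,           # zoom
                           rotate_limit=15, p=0.7),   # rotation
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
)

# Example usage on one frame:
# augmented = traditional_augment(image=frame, bboxes=polyp_boxes, labels=polyp_labels)
```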

Experiments

We use CVC-ClinicDB [14] for training and ETIS-Larib [15] for testing. This allows us to make a fair comparison with the MICCAI 2015 challenge results, which are reported on the same datasets. The CVC-ClinicDB database contains 612 polyp image frames of 31 unique polyps from 31 different colonoscopy videos. The ETIS-Larib dataset contains 196 high-resolution image frames of 44 different polyps.

fp_figure

Fig. 2. Examples of false-positive detections on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.

fn_figure

Fig. 3. Examples of false-negative detections on the testing dataset. Green boxes and blue boxes are ground truths and predictions, respectively.

Fig. 2 and Fig. 3 visualize several failure cases of our model on the testing dataset, in which the blue boxes are the predicted locations and the green boxes are ground truths. The false-positive samples (Fig. 2) are caused by shortcomings in bowel preparation (i.e., leftovers of food and fluid in the colon), while the false-negative samples (Fig. 3) are caused by variations in polyp type and appearance (i.e., small polyps, flat polyps, and similarities between polyps and colon veins).

tab-3

Table 2. Comparison among traditional data augmentation (TDA), auto-augmentation (AA), and their combinations.

Table 2 shows the comparative results between the different augmentation strategies. The results show that the third combination method (AA-TDA-3) achieves higher Precision than AutoAugment, i.e., 75.90% versus 74.51%. However, overall, auto-augmentation (AA) achieves the best results because of its performance on the polyp miss rate (i.e., 152) with an acceptable number of false positives (i.e., 52). The competitive performance of auto-augmentation (AA) confirms that the data augmentation policies learned on the COCO dataset transfer well to polyp detection [4].

tab-4

Table 3. Comparative results between our model and the state of the art.

Table 3 presents the comparative results between the auto-augmentation in our model and other state-of-the-art methods. Among the compared methods, CUMED, OUR, and UNS-UCLAN are end-to-end deep learning-based approaches. The results show that, compared to the methods from the MICCAI challenge [3], auto-augmentation achieves better performance on all metrics. Compared to the recent method [10], auto-augmentation also achieves better performance on all metrics except FP. These results confirm the effectiveness of auto-augmentation for the polyp detection problem.

Conclusion

This study adopts a deep learning-based object detection method with automatic data augmentation for the polyp detection problem. Different augmentation strategies are evaluated. The experimental results show that the auto-augmentation policies learned from a general object detection dataset transfer well to the polyp detection problem. Although auto-augmentation achieves competitive results, it still has a high FP count compared to the state of the art. This weakness can be improved by post-processing steps such as false-positive learning.

Open Source

🍅 Github: https://github.com/aioz-ai/polyp-detection

🍓 Blog post: https://ai.aioz.io/blog/polyp-detection

Acknowledgements

This research was conducted by Phong Nguyen, Quang Tran, Erman Tjiputra, and Toan Do. We’d like to give special thanks to the other AIOZ AI team members for their support and feedback.

🎉 All the above contributions were incredibly enabling for this research. 🎉

Reference

[1] Hamidreza Sadeghi Gandomani, Mohammad Aghajani, et al., “Colorectal cancer in the world: incidence, mortality and risk factors,” Biomedical Research and Therapy, 2017.

[2] Florence Bénard, Alan N. Barkun, et al., “Systematic review of colorectal cancer screening guidelines for average-risk adults: Summarizing the current global recommendations,” World Journal of Gastroenterology, 2018.

[3] Jorge Bernal, Nima Tajbakhsh, et al., “Comparative validation of polyp detection methods in video colonoscopy: results from the MICCAI 2015 endoscopic vision challenge,” IEEE Transactions on Medical Imaging, pp. 1231–1249, 2017.

[4] Barret Zoph, Ekin D. Cubuk, et al., “Learning data augmentation strategies for object detection,” arXiv, 2019.

[5] Shaoqing Ren, Kaiming He, Ross Girshick, and Jian Sun, “Faster R-CNN: Towards real-time object detection with region proposal networks,” in NIPS, 2015.

[6] Ekin D. Cubuk, Barret Zoph, et al., “AutoAugment: Learning augmentation strategies from data,” in CVPR, 2019.

[7] Younghak Shin, Hemin Ali Qadir, et al., “Automatic colon polyp detection using region based deep CNN and post learning approaches,” IEEE Access, 2018.

[8] Yangqing Jia, Evan Shelhamer, et al., “Caffe: Convolutional architecture for fast feature embedding,” in ACMMM, 2014.