
CathAction - A Benchmark for Endovascular Intervention Understanding (Part 3)

1. Tasks and Benchmarks

In this section, we benchmark five tasks, including anticipation, recognition, segmentation, collision detection, and domain adaptation, to demonstrate the usefulness of the CathAction dataset. We then discuss the challenges and opportunities for improvement in each task.

A. Catheterization Anticipation

The anticipation task aims to predict the next catheterization action from a sequence of frames. We adapt the conventional anticipation framework in computer vision, which introduces two timing parameters: the anticipation time ($\tau_a$) and the observation time ($\tau_o$). The anticipation time denotes the horizon over which the upcoming action must be predicted, while the observation time indicates the length of video footage analyzed before making a prediction. The objective is to predict the action class $c_a$ for the frames within the anticipation time $\tau_a$, given the frames during the observation time $\tau_o$.
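
As a minimal sketch of this setup (the helper name `split_windows` and the indexing convention are assumptions for illustration, not the benchmark code), the observation and anticipation windows can be sliced from a 24 FPS fluoroscopy clip as follows:

```python
import numpy as np

FPS = 24      # CathAction videos are encoded at 24 frames per second
TAU_O = 1.0   # observation time in seconds
TAU_A = 1.0   # anticipation time in seconds

def split_windows(frames: np.ndarray, obs_end: int):
    """Slice the observed clip and the anticipation window.

    frames: array of shape (num_frames, H, W) for one fluoroscopy video.
    obs_end: frame index where the observation window ends.
    Returns the observed clip (model input) and the anticipation window,
    i.e., the frames whose action class c_a must be predicted.
    """
    obs_len, ant_len = int(TAU_O * FPS), int(TAU_A * FPS)
    obs_clip = frames[max(0, obs_end - obs_len): obs_end]   # observed tau_o seconds
    ant_window = frames[obs_end: obs_end + ant_len]          # upcoming tau_a seconds
    return obs_clip, ant_window
```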

Network and Training. We leverage state-of-the-art action anticipation methods as baselines: CNN&RNN, RU-LSTM, TempAggRe-Fusion, AFFT, and Trans-SVNet. The future action predictions are supervised with a cross-entropy loss against the labeled future actions. Following prior work, we set $\tau_a = 1$ s and $\tau_o = 1$ s. Training is performed on a single Nvidia A100 GPU with a batch size of 64 for 80 epochs, starting with a learning rate of 0.001 that is reduced by a factor of 10 at epochs 30 and 60. We split approximately 80% of the dataset for training and 20% for testing. Performance metrics include top-1 accuracy, precision, and recall.

| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| CNN | CVPR 2018 | 28.98 | 30.14 | 29.76 |
| RNN | CVPR 2018 | 29.64 | 30.38 | 30.44 |
| RU-LSTM | CVPR 2019 | 35.08 | 34.29 | 34.77 |
| TempAggRe | ECCV 2020 | 34.64 | 35.56 | 34.71 |
| Trans-SVNet | IJCARS 2022 | 29.06 | 19.67 | 20.28 |
| AFFT | WACV 2023 | 37.91 | 36.87 | 37.63 |

Table 1: Catheterization anticipation results on the CathAction dataset. All values are reported in percentages (%).

Figure 1

Figure 1: Qualitative catheterization prediction results. The predicted and ground-truth next actions are displayed on the right of each sample. Green denotes a correct prediction, and red denotes an incorrect prediction.

Results. Table 1 shows the catheterization anticipation results of the different baselines. Transformer-based methods outperform CNN- and LSTM-based models. Qualitative results are illustrated in Fig. 1: transformer-based models make more accurate predictions in challenging scenarios, especially when the catheter is moving quickly or is partially occluded.

Discussion. Despite the advancements, existing methods for catheterization anticipation still struggle to achieve high accuracy, revealing areas for future research. The rapid motion of the catheter and guidewire poses significant challenges for this task, and real-time performance is crucial as surgeons require immediate feedback during procedures.

B. Catheterization Recognition

Catheterization recognition follows the traditional action recognition task in computer vision: given an input video segment, the goal is to predict the action class of that segment.

Network and Training. We explore state-of-the-art action recognition methods to benchmark the catheterization recognition task, including TDN, Video Swin Transformer, and BERT Pretraining of Video Transformers (BEVT). Each model is trained using two Nvidia A100 GPUs for 80 epochs, with a mini-batch size of 512. The initial learning rates are set to 0.01 for the spatial stream and 0.001 for the temporal stream, reduced by a factor of 10 at epochs 20 and 40. All other parameters are reused from the baseline methods.

Results. Table 2 summarizes the catheterization recognition results of three baseline methods (TDN, Video Swin Transformer, and BEVT) on the CathAction dataset. TDN with a ResNet101 backbone achieves the best top-1 accuracy of 62.5% on five classes. Action recognition in endovascular intervention remains challenging because catheters and guidewires look similar across different environments, while the actions themselves depend on the visual characteristics of these instruments.

| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| TDN-ResNet50 | CVPR 2021 | 58.34 | 59.12 | 57.22 |
| TDN-ResNet101 | CVPR 2021 | 62.50 | 61.89 | 62.77 |
| Video Swin Transformer | CVPR 2022 | 51.67 | 52.14 | 51.24 |
| BEVT | CVPR 2022 | 49.28 | 50.27 | 49.92 |

Table 2: Catheterization recognition results on the CathAction dataset. All values are reported in percentages (%).

Discussion. Compared to the anticipation task (Table 1), catheterization recognition methods (Table 2) show higher accuracy. However, the overall performance is not yet sufficient for real-world applications. Further research can utilize advanced techniques such as multi-modality learning, combining pre-operative or synthetic data with transfer learning to improve the results. Additionally, exploring the capabilities of large-scale medical foundation models is an interesting research direction.

C. Catheter and Guidewire Segmentation

Catheter and guidewire segmentation is a well-known task in endovascular interventions. In this task, we aim to segment the catheter and guidewire from the background. Unlike catheterization recognition or anticipation, which take a video as input, this segmentation task only uses the X-ray image as input.

| Baseline | Dice Score | Jaccard Index | mIoU | Accuracy |
|---|---|---|---|---|
| UNet | 51.69 | 57.51 | 31.17 | 63.26 |
| TransUNet | 56.52 | 55.93 | 34.13 | 55.61 |
| SwinUNet | 61.26 | 59.54 | 39.53 | 76.60 |
| SSL | 56.95 | 56.87 | 40.87 | 72.24 |
| SegViT | 63.47 | 54.12 | 42.48 | 68.73 |

Table 3: Segmentation results on the CathAction dataset.

Network and Training. We benchmark UNet, TransUNet, SwinUNet, SSL, and SegViT. We follow the default training and testing configurations provided in the published papers. We use the Dice Score, Jaccard Index, mIoU, and Accuracy as the evaluation metrics for the segmentation task.
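
For reference, the segmentation metrics above can be computed from binary masks roughly as sketched below; the exact per-class and per-image averaging in the benchmark code may differ, so this is an illustrative implementation only:

```python
import numpy as np

def dice_score(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-7) -> float:
    """Dice coefficient between two binary masks of the same shape."""
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def jaccard_index(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-7) -> float:
    """Jaccard index (IoU) between two binary masks."""
    inter = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    return (inter + eps) / (union + eps)

def mean_iou(pred_labels: np.ndarray, gt_labels: np.ndarray, num_classes: int) -> float:
    """mIoU averaged over classes (e.g., background, catheter, guidewire)."""
    ious = [jaccard_index(pred_labels == c, gt_labels == c) for c in range(num_classes)]
    return float(np.mean(ious))
```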

Results. Table 3 shows the catheter and guidewire segmentation results. The table shows that transformer-based networks such as TransUNet or SegViT achieve higher accuracy than the traditional UNet. SegViT, which utilizes a vision transformer backbone, shows the best performance; however, its margin over the other methods is not large.

Discussion. In contrast to traditional segmentation tasks in computer vision, which typically involve objects occupying substantial portions of an image, the segmentation of catheters and guidewires presents a considerably greater challenge. These elongated instruments have extremely slender bodies, making their spatial presence in the image less pronounced. Additionally, the unique characteristics of X-ray images can lead to misidentification of catheters or guidewires as blood vessels. Addressing these challenges in future research is imperative to enhance the accuracy of segmentation outcomes.

D. Collision Detection

Detecting the collision of the tip of the catheter or guidewire with the blood vessel wall is an important task in endovascular intervention. We define collision detection as an object detection problem. In particular, the tip of the catheter or guidewire in all frames of our dataset is annotated with a bounding box. Each bounding box is assigned one of two classes: collision, when the tip collides with the blood vessel wall, or normal, when there is no collision.

Network and Training. We use YOWO, YOWO-Plus, STEP, and HIT. Since the bounding boxes in our ground truth have relatively small sizes, we also explore tiny object detection methods such as Yolov and EFF. The training process starts with a learning rate of 0.0003, which is then decreased by a factor of 10 after 20 epochs, concluding at 80 epochs. We train all methods with a mini-batch size of 4 on an Nvidia A100 GPU. The average precision (AP) and mean average precision (mAP) are used to evaluate the detection results.
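
As an illustration of the evaluation protocol, the snippet below sketches a box IoU and a simplified per-class average precision (AP, computed as the area under the precision-recall curve); the benchmark itself likely follows a standard detection AP protocol, so treat this as an assumption-laden reference rather than the exact implementation:

```python
import numpy as np

def box_iou(a, b):
    """IoU between two boxes given as (x1, y1, x2, y2)."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-7)

def average_precision(scores, is_true_positive, num_gt):
    """AP for one class (collision or normal) from detection scores and TP flags."""
    order = np.argsort(-np.asarray(scores))
    tp = np.asarray(is_true_positive, dtype=float)[order]
    fp = 1.0 - tp
    tp_cum, fp_cum = np.cumsum(tp), np.cumsum(fp)
    recall = tp_cum / max(num_gt, 1)
    precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-7)
    return float(np.trapz(precision, recall))  # area under the precision-recall curve
```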

| Baseline | AP (Collision) | AP (Normal) | AP (Mean) | mAP (Collision) | mAP (Normal) | mAP (Mean) |
|---|---|---|---|---|---|---|
| STEP | 7.79 | 11.21 | 10.98 | 6.92 | 11.29 | 9.08 |
| YOWO | 8.32 | 12.18 | 11.73 | 7.46 | 12.28 | 9.92 |
| YOWO-Plus | 8.92 | 12.23 | 11.77 | 7.86 | 12.48 | 10.28 |
| HIT | 9.37 | 12.74 | 12.14 | 8.18 | 12.72 | 10.81 |
| Yolov* | 12.30 | 21.08 | 15.89 | 11.88 | 20.04 | 14.11 |
| EFF* | 13.70 | 22.10 | 16.91 | 12.14 | 20.78 | 14.88 |

Table 4: Collision detection results on the CathAction dataset. The symbol (*) denotes tiny object detectors.

Figure 2

Figure 2: Qualitative results for the collision detection task. The first two columns visualize the collision results, the third column visualizes no collision cases, and the last column visualizes a failure case where the tip was not detected.

Results. Table 4 shows the collision detection results. The table indicates that tiny object detectors such as Yolov and EFF achieve higher accuracy than general-purpose object detectors. Furthermore, we observe that the performance of all methods remains relatively low, which highlights the challenges that lie ahead for collision detection in endovascular intervention. Figure 2 shows detection examples where EFF has difficulty detecting collisions between the catheter and the blood vessel.

Discussion. Compared to traditional object detection results on vision datasets, the collision detection results on our dataset are significantly lower, with the top mean AP being only 16.91. The challenges of this task come from two factors. First, the tip of the catheter or guidewire is relatively small in X-ray images. Second, the imbalance between the collision and normal classes makes the problem more difficult. Therefore, there is a need to develop specialized methods to address these difficulties. Future work may rely on attention mechanisms, transformers, or foundation models to develop more effective endovascular collision detectors.

E. Domain Adaptation

Our dataset is sourced from two distinct environments: vascular phantom data and animal data. To assess the capacity for learning from phantom data and applying it to real data, we benchmark endovascular intervention tasks under domain adaptation setups. For each task, we train the model on phantom data and then test it on real animal data. In practice, animal data closely resembles data captured from humans, and performing these tasks on real animal or human data is much more challenging.
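
A minimal sketch of this cross-domain split is shown below (the metadata field names are hypothetical; the idea is simply that the training and evaluation sets come from different sources):

```python
# Hypothetical sample metadata for the domain adaptation protocol.
cathaction_samples = [
    {"video": "phantom_001.mp4", "source": "phantom", "label": "advance catheter"},
    {"video": "animal_001.mp4", "source": "animal", "label": "advance guidewire"},
]

train_set = [s for s in cathaction_samples if s["source"] == "phantom"]  # training domain
test_set = [s for s in cathaction_samples if s["source"] == "animal"]    # evaluation domain
```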

| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| RU-LSTM | CVPR 2019 | 22.93 | 23.91 | 22.57 |
| TempAggRe | ECCV 2020 | 17.16 | 18.41 | 18.23 |
| Trans-SVNet | IJCARS 2022 | 19.06 | 17.67 | 19.58 |
| AFFT | WACV 2023 | 25.67 | 26.29 | 26.33 |

Table 5: Catheterization anticipation results under domain adaptation setup. All methods are trained on phantom data and tested on animal data.

Anticipation Adaptation. We use the same methods RU-LSTM, TempAggRe, Trans-SVNet, and AFFT for anticipation adaptation experiments. Table 5 shows the results. Compared with the setup in Table 1, we can see that there is a significant accuracy drop. This highlights the challenges of applying baseline methods in practical real-world scenarios, particularly when dealing with unforeseen situations in catheterization procedures.

| Baseline | Venue | Accuracy | Precision | Recall |
|---|---|---|---|---|
| TDN-ResNet50 | CVPR 2021 | 24.19 | 23.17 | 24.56 |
| TDN-ResNet101 | CVPR 2021 | 25.62 | 24.52 | 25.68 |
| Video Swin Transformer | CVPR 2022 | 28.79 | 27.98 | 28.12 |
| BEVT | CVPR 2022 | 31.22 | 30.48 | 31.79 |

Table 6: Catheterization recognition results under domain adaptation setup.

Recognition Adaptation. We repeat the catheterization recognition task under the domain adaptation setup. Table 6 shows the results when all baselines are trained on phantom data and tested on animal data. This table also demonstrates that the domain adaptation setup is very challenging: compared with the normal setting in Table 2, accuracy drops by approximately 30%.

| Baseline | AP (Collision) | AP (Normal) | AP (Mean) | mAP (Collision) | mAP (Normal) | mAP (Mean) |
|---|---|---|---|---|---|---|
| STEP | 1.53 | 2.12 | 1.87 | 1.09 | 1.98 | 1.62 |
| YOWO | 2.12 | 4.11 | 3.09 | 1.97 | 3.68 | 2.92 |
| YOWO-Plus | 1.18 | 1.43 | 1.21 | 1.07 | 1.26 | 1.09 |
| HIT | 1.31 | 1.19 | 1.24 | 1.06 | 1.18 | 1.11 |
| Yolov* | 7.31 | 8.92 | 8.09 | 6.28 | 7.49 | 7.21 |
| EFF* | 8.27 | 9.16 | 8.19 | 7.61 | 8.29 | 7.88 |

Table 7: Collision detection results under domain adaptation setup. All methods are trained on phantom data and tested on animal data. The symbol (*) denotes tiny object detectors.

Collision Detection Adaptation. Table 7 shows the collision detection results under domain adaptation. We can see that under this setup, most object detection methods achieve very low accuracy. Therefore, there is an immediate need to improve or design new methods that can detect collisions in real time for endovascular catheterization procedures.

| Baseline | Dice Score | Jaccard Index | mIoU | Accuracy |
|---|---|---|---|---|
| UNet | 26.58 | 31.38 | 12.13 | 46.07 |
| TransUNet | 16.16 | 24.19 | 17.23 | 33.61 |
| SwinUNet | 17.41 | 38.14 | 7.52 | 40.79 |
| SSL | 26.91 | 32.04 | 18.72 | 42.44 |
| SegViT | 30.74 | 32.22 | 11.46 | 50.00 |

Table 8: Domain adaptation segmentation results.

Segmentation Adaptation. Table 8 shows the catheter and guidewire segmentation results when the networks are trained on phantom data and tested on animal data. Similar to the other tasks under the domain adaptation setting, we observe a significant accuracy drop for all methods. Overall, SegViT still outperforms the other segmentation methods, which suggests that a vision transformer backbone may be a good solution for this task.

2. Discussion

We introduce CathAction, a large-scale dataset for endovascular intervention tasks, encompassing annotated ground truth for segmentation, action understanding, and collision detection. While CathAction marks a significant advancement in endovascular interventions, it is important to acknowledge certain limitations. First, despite its comprehensiveness, the dataset may not encompass every possible clinical scenario and could lack representation for rare or outlier cases. Second, our work currently benchmarks vision-based methods, which still exhibit insufficient accuracy, and challenges in generalizability and adaptability to real-world scenarios persist, as highlighted by the results presented in Section 1 for the catheterization anticipation, recognition, segmentation, and collision detection tasks. Third, we mostly utilize metrics from the vision community to evaluate the results. These metrics may not fully reflect clinical needs, and the continuous refinement of evaluation metrics and the exploration of potential interdependencies among tasks demand further research.

From our intensive experiments, we see several research directions that can benefit from our large-scale dataset:

  1. There is an immediate need to develop more advanced methods for catheterization anticipation, recognition, collision detection, and action understanding, especially under the domain adaptation setup. Future work can explore the potential of graph neural networks, temporal information, and multimodal or transfer learning to improve the accuracy and reliability of the methods.
  2. Currently, we address endovascular intervention tasks independently; future work can combine these tasks and tackle them simultaneously (e.g., the anticipation and collision detection tasks can be jointly trained). This would make the research outputs more useful in clinical practice.
  3. Given that CathAction is a large-scale dataset, it can be used to train a foundation model for endovascular interventions or related medical tasks.

3. Conclusion

We introduce CathAction as a large-scale dataset for endovascular intervention research, offering the largest and most comprehensive benchmark to date. With intensive annotated data, CathAction addresses crucial limitations in existing datasets and helps connect computer vision with healthcare tasks. By providing a standardized dataset with public code and public metrics, CathAction promotes transparency, reproducibility, and the collective exploration of different tasks in the field. Our code and dataset are publicly available to encourage further study.

CathAction - A Benchmark for Endovascular Intervention Understanding (Part 2)

1. The CathAction Dataset

This section introduces the CathAction dataset. Specifically, we describe the data collection process and annotation pipeline. We then present statistics regarding different aspects of our large-scale dataset.

Data Collection

Given that endovascular intervention constitutes a medical procedure, acquiring extensive human data is often impractical and time-consuming due to privacy constraints. To address this challenge, we suggest an alternative approach involving the collection of data from two distinct sources:

  1. Utilizing vascular soft silicone phantoms modeled after the human body.
  2. Employing animal subjects, specifically pigs. The selection of pigs is justified by their vascular anatomy, which is widely acknowledged as highly analogous to that of humans.

Ethics
Since our data collection involves experiments with radiation sources (X-ray radiology fluoroscopic systems) and live animals, all relevant ethical approvals were obtained before the collection process. The data were collected by well-trained, professional endovascular surgeons wearing protective suits, as is daily practice in the hospital.

Figure 1

Figure 1: The human silicone phantom model (a), and the data collection setup in the operating room (b).

Phantom Setup
To ensure that data is collected from various models, we use five adult human aortic arch phantoms made of soft silicone, manufactured by Elastrat Ltd., Switzerland. To enhance realism in the interaction between surgical tools and tissues, the phantoms are connected to a pulsatile pump to simulate the flow of normal human blood. All phantoms are placed beneath an X-ray imaging system to mimic a patient lying on an angiography table, preparing for an endovascular procedure.

Animal Setup
We use five live pigs as subjects for data collection. The animal setup is identical to that of a human procedure. During the endovascular intervention, professional surgeons use an iodine-based contrast agent to enhance visibility of specific structures or fluids within the body. Iodine contrast agents are radiopaque, meaning they absorb X-rays, resulting in improved visibility of blood vessels, organs, and other structures such as the catheter and guidewire during imaging.

Figure 2

Figure 2: Example data collected with phantom models (top row) and animals (bottom row). Animal data are more challenging with less visible catheters or guidewires.

Data Collection
Ten skilled professional surgeons are tasked with cannulating three arteries, namely the left subclavian artery (LSA), the left common carotid artery (LCCA), and the right common carotid artery (RCCA), using a commercial catheter and guidewire. Throughout each catheterization process, the surgeon activates the X-ray fluoroscopy using a pedal in the operating room. We developed a real-time image grabber to transmit the video feed of the surgical scene to a workstation. The experiments are conducted under two interventional radiology fluoroscopic systems: the Innova 4100 IQ GE Healthcare and the EMD Technologies Epsilon X-ray Generator. Fig. 1 shows the data collection setup with the human silicone phantoms, and Fig. 2 visualizes the data collected with phantom models and real animals. From Fig. 2, we can see that there is a large domain gap between data collected using phantom models and live animals.

Data Annotation

Actions
Based on advice from expert endovascular surgeons, we define five classes to annotate catheterization actions. These classes fall into three groups: catheter (`advance catheter` and `retract catheter`), guidewire (`advance guidewire` and `retract guidewire`), and one action involving both the catheter and guidewire (`rotate`). Surgeons typically rotate both the catheter and guidewire simultaneously, so we use one rotation class. We utilize a free, open-source video editor to annotate the start and end times of each narrated action. All fluoroscopy videos are processed at a 500 × 500 resolution and 24 frames per second (FPS). To ensure annotation quality, all ground-truth actions are manually checked and modified by an experienced endovascular surgeon.
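
For clarity, the five annotated action classes can be represented with a simple label mapping; the class names follow the annotation scheme above, while the integer ids are purely illustrative:

```python
# Illustrative label mapping for the five catheterization action classes.
ACTION_CLASSES = {
    0: "advance catheter",
    1: "retract catheter",
    2: "advance guidewire",
    3: "retract guidewire",
    4: "rotate",  # catheter and guidewire are rotated together, hence one class
}
```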

Collision Annotation
In practice, the collision between the catheter (or guidewire) and the blood vessel wall mainly occurs at the instrument's tip. Therefore, for each frame of the fluoroscopy video, we annotate the catheter (or guidewire) tip with a bounding box. There are two classes for the bounding boxes: `collision` (when the instrument collides with the blood vessel) and `normal` (when there is no collision). We used an open-source labeling tool to annotate bounding boxes in each video, with all videos encoded at 24 FPS to ensure dataset coherence.

Segmentation
The combination of guidewire and catheter is common in endovascular interventions, where precise navigation through blood vessels is essential for procedure success. Unlike most previous datasets that consider both catheter and guidewire as one class, we manually label catheter and guidewire classes separately in our dataset. Our segmentation ground truth thus provides a more detailed understanding of endovascular interventions.

Dataset Statistics

Overview
As summarized in Table 1 of Part 1, CathAction is a large-scale benchmark for endovascular interventions. Our dataset consists of approximately 500,000 annotated frames for action understanding and collision detection, and around 25,000 ground-truth masks for catheter and guidewire segmentation. There are a total of 569 videos in our dataset. Some collected video samples are illustrated in Fig. 2. We believe CathAction is currently the largest, most challenging, and most comprehensive dataset of endovascular interventions.

Statistics
The CathAction dataset is annotated with a primary focus on catheters and guidewires. Fig. 3 provides an overview of the distribution of action classes in both animal and phantom data, while Fig. 4 portrays the distribution of action segment lengths, illustrating the substantial variability in segment duration. Additionally, Fig. 5 visually compares the number of bounding boxes between phantom and animal data, revealing a significant disparity between counts of normal and collision boxes, as expected due to the infrequency of collisions in real-world scenarios.

Figure 3

Figure 3: Distribution of the number of action classes in the CathAction dataset. Left-side: Distribution on real animal data. Right-side: Distribution on phantom data.

Figure 4

Figure 4: Duration distribution of action segments in the CathAction dataset, on real animal data and phantom data.

Figure 5

Figure 5: Comparison of the number of bounding box objects in real animal data and phantom data.

Adaptation Property
Since data is collected from two sources—phantoms and real animals—a domain gap exists between the two data types. Fig. 2 and Fig. 5 also demonstrate the adaptation property shared between phantom and animal data. This distinctive domain gap renders CathAction a formidable benchmark for evaluating domain adaptation, a critical problem in medical domains where collecting real human data is often infeasible. Using CathAction, we can develop domain adaptation techniques, learning from synthetic or phantom data and effectively applying that knowledge to genuine animal or human data, bridging the gap between controlled simulation and real-world scenarios.

Next

In the next post, we will benchmark our new dataset CathAction on various tasks.

CathAction - A Benchmark for Endovascular Intervention Understanding (Part 1)

Real-time visual feedback from catheterization analysis is crucial for enhancing surgical safety and efficiency during endovascular interventions. However, existing datasets are often limited to specific tasks, small in scale, and lack the comprehensive annotations necessary for broader endovascular intervention understanding. To tackle these limitations, we introduce CathAction, a large-scale dataset for catheterization understanding. Our CathAction dataset encompasses approximately 500,000 annotated frames for catheterization action understanding and collision detection, and 25,000 ground truth masks for catheter and guidewire segmentation. For each task, we benchmark recent related works in the field. We further discuss the challenges of endovascular interventions compared to traditional computer vision tasks and point out open research questions. We hope that CathAction will facilitate the development of endovascular intervention understanding methods that can be applied to real-world applications.

1. Introduction

| Dataset | Collection | Type | #Frames | Source | Annotation | Public | Task |
|---|---|---|---|---|---|---|---|
| Barbu et al. | X-ray | Video | 535 | Real | Manual | No | Segmentation |
| Wu et al. | 3D Echo | Video | 800 | Real | Manual | No | Segmentation |
| Ambrosini et al. | X-ray | Image | 948 | Real | Manual | No | Segmentation |
| Mastmeyer et al. | 3D MRI | Image | 101 | Real | Manual | No | Segmentation |
| Yi et al. | X-ray | Image | 2,540 | Synthesis | Automatic | No | Segmentation |
| Nguyen et al. | X-ray | Image | 25,271 | Phantom | Semi-Auto | No | Segmentation |
| Danilov et al. | 3D Ultrasound | Video | 225 | Synthetic | Manual | No | Segmentation |
| Delmas et al. | X-ray | Image | 2,357 | Simulated | Automatic | No | Reconstruction |
| Brost et al. | X-ray | Image | 938 | Clinical | Semi-Auto | No | Tracking |
| Ma et al. | X-ray, CT | Image | 1,048 | Clinical | Manual | No | Reconstruction |
| CathAction (ours) | X-ray | Video | 500,000+ | Phantom & Animal | Manual | Yes | Segmentation, Action Understanding, Collision Detection |

Table 1: Endovascular intervention datasets comparison.

Cardiovascular diseases are one of the leading causes of death worldwide. Endovascular intervention has become the gold standard treatment for these diseases, preferred for its advantages over traditional open surgery, including smaller incisions, reduced trauma, and lower risks of comorbidities for patients. Endovascular interventions involve maneuvering small and long medical instruments, such as catheters and guidewires, within the vasculature through small incisions to reach targeted areas for treatment delivery, such as artery stenting, tissue ablation, and drug delivery. However, such tasks require high technical skill, with the primary challenge being to avoid collisions with the vessel wall, which could result in severe consequences, including perforation, hemorrhage, and organ failure. In practice, surgeons rely on 2D X-ray fluoroscopy images to perform these tasks within the 3D human body, which adds a significant challenge in safely controlling the catheter and guidewire.

Recently, learning-based methods for computer-assisted intervention systems have emerged for diverse tasks. Numerous methodologies have been developed to address the challenges of endovascular interventions, including catheter and guidewire segmentation, vision-based force sensing, learning from demonstration, and skill training assistance. Additionally, various deep learning approaches have been proposed for specific tasks in endovascular interventions, such as instrument motion recognition in X-ray sequences, interventionalist hand motion recognition, and collision detection. However, due to challenges in acquiring medical data, most of these methods rely on synthetic data or small, private datasets. Consequently, despite the critical nature of interventions, current methods have not fully capitalized on recent advancements in deep learning, which typically require large-scale training data.

Over the years, several datasets for endovascular intervention have been introduced. Table 1 shows a detailed comparison between current endovascular intervention datasets. However, these datasets share common limitations. First, they are relatively small in terms of the number of images, as collecting real-world medical data is costly. Second, due to privacy challenges in the medical domain, most existing datasets are kept private. Finally, these datasets are often created for a single task, such as segmentation, and do not support other important tasks in endovascular interventions, such as collision detection or action understanding.


To address these issues, we present CathAction, a large-scale dataset encompassing several endovascular intervention tasks, including segmentation, collision detection, and action understanding. To our knowledge, CathAction represents the largest and most realistic dataset specifically tailored for catheter and guidewire tasks.

In summary, we make the following contributions:

  • We introduce CathAction, a large-scale dataset for endovascular interventions, providing manually labeled ground truth for segmentation, action understanding, and collision detection.
  • We benchmark key tasks in endovascular interventions, including catheterization anticipation, recognition, segmentation, and collision detection.
  • We discuss the challenges and open questions in endovascular intervention. Our code and dataset are publicly available.

2. Related Work

Endovascular Intervention Dataset
Several endovascular intervention datasets have been introduced. Barbu et al. proposed a dataset for localizing the entire guidewire and validated it using a traditional threshold-based method. Other datasets consider fluoroscopy videos at the image level, with mask annotations for each frame of the fluoroscopy videos. For instance, Ambrosini et al. developed a dataset with 948 annotated segmentation masks that consider the catheter and guidewire as one class. Similarly, Mastmeyer et al. collected and annotated a dataset with 101 segmentation masks for real catheters from 3D MRI data. More recently, Nguyen et al. proposed a dataset that also treats the catheter and guidewire as one class. Overall, most of these datasets are limited in terms of size, task categories, and focus. To overcome these limitations, we introduce CathAction, a large-scale dataset with various tasks, including catheter and guidewire segmentation, collision detection, and catheterization action recognition and anticipation. The CathAction dataset enables the development of more accurate and reliable deep learning methods for endovascular interventions.

Catheterization Action Understanding
Deep learning techniques have demonstrated notable achievements in endovascular intervention action understanding. Jochem et al. presented one of the first works utilizing deep learning for catheter and guidewire activity recognition in fluoroscopy sequences. Subsequently, deep learning-based approaches have gained prominence as the most widely utilized solution for interventionalist hand motion recognition. For instance, Akinyemi et al. introduced a deep learning model based on convolutional neural networks (CNNs) that incorporates convolutional layers for automatic feature extraction and identifies operators' actions. Additionally, Wang et al. proposed a multimodal fusion architecture for recognizing eight common operating behaviors of interventionists. Despite extensive research on deep learning methods for endovascular intervention, these approaches are constrained by the limited availability of medical data: most of them rely on synthetic data or small, private datasets. As a result, although intervention is a crucial procedure, it has not fully benefited from recent deep learning advances, which usually require large-scale training data.

Catheter and Guidewire Segmentation
Catheter and guidewire segmentation is crucial for real-time endovascular interventions. Many methods have been proposed to address the challenges of catheter and guidewire segmentation. The outcomes of catheter and guidewire segmentation can be applied in vision-based force sensing, learning from demonstration, or skill training assistance applications. Traditional methods for catheterization segmentation adopt thresholding-based techniques and do not generalize well on X-ray data. Deep learning methods can learn meaningful features from input data, but they are challenging to apply to catheter segmentation due to the lack of real X-ray data and the tediousness of manual ground truth labeling. Many current learning-based techniques for catheter segmentation and tracking are limited to training on small-scale datasets or synthetic data due to the challenges of large-scale data collection. Our dataset provides manual ground truth labels for both the catheter and guidewire, offering a substantial basis for developing catheter and guidewire segmentation methods.

Collision Detection
Collision detection is a crucial task in endovascular interventions to ensure patient safety. Several attempts have been made to incorporate deep learning models into collision detection, but these methods have focused on identifying risky actions in simulated datasets. While existing methods can be useful for identifying potential hazards, they cannot localize the position of collisions or provide visual feedback. Additionally, these methods have not been widely used in real-world settings due to the lack of annotated bounding boxes for collisions of guidewire tips with vessel walls. Our dataset addresses this limitation by providing annotated bounding boxes for collision events in both phantom and real-world data. This enables the development of deep learning models that can detect collisions in real-time and provide visual or haptic feedback to surgeons.

Next

In the next post, we will describe our new dataset CathAction.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 3)

Grasping Machine

1. Experiments

Experiment Setup

Dataset. We use the Grasp-Anything dataset in our experiment. Grasp-Anything is a large-scale dataset for language-driven grasp detection with 1M samples. Each image in the dataset is accompanied by one or several prompts describing a general object-grasping action or grasping an object at a specific location.

Evaluation Metrics. Our primary evaluation metric is the success rate, defined similarly to previous works: the IoU score between the predicted grasp and the ground truth grasp must exceed 25%, and the offset angle must be less than $30^\circ$. We also use the harmonic mean (H) to measure the overall success rates. All methods' latency (inference time) in seconds is reported using the same NVIDIA A100 GPU.
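
A minimal sketch of this success criterion is given below, assuming rectangle grasps parameterized as (center x, center y, width, height, angle) and a simple rasterized mask IoU; the benchmark's exact rectangle-IoU implementation may differ:

```python
import numpy as np

def rect_mask(cx, cy, w, h, theta, size=224):
    """Rasterize a rotated grasp rectangle (center, width, height, angle in radians)."""
    ys, xs = np.mgrid[0:size, 0:size]
    x, y = xs - cx, ys - cy
    # Rotate pixel coordinates into the rectangle's frame.
    xr = x * np.cos(theta) + y * np.sin(theta)
    yr = -x * np.sin(theta) + y * np.cos(theta)
    return (np.abs(xr) <= w / 2) & (np.abs(yr) <= h / 2)

def grasp_success(pred, gt, iou_thresh=0.25, angle_thresh=np.deg2rad(30)):
    """pred/gt: (cx, cy, w, h, theta). Success if IoU > 25% and angle offset < 30 degrees."""
    m_pred, m_gt = rect_mask(*pred), rect_mask(*gt)
    iou = np.logical_and(m_pred, m_gt).sum() / (np.logical_or(m_pred, m_gt).sum() + 1e-7)
    # Grasp rectangles are symmetric under 180-degree rotation, so fold the angle difference.
    d_theta = np.abs((pred[4] - gt[4] + np.pi / 2) % np.pi - np.pi / 2)
    return iou > iou_thresh and d_theta < angle_thresh
```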

Comparison with Grasp Detection Methods

| Baseline | Seen | Unseen | H | Latency (s) |
|---|---|---|---|---|
| GR-ConvNet | 0.37 | 0.18 | 0.24 | 0.022 |
| Det-Seg-Refine | 0.30 | 0.15 | 0.20 | 0.200 |
| GG-CNN | 0.12 | 0.08 | 0.10 | 0.040 |
| CLIPORT | 0.36 | 0.26 | 0.29 | 0.131 |
| CLIP-Fusion | 0.40 | 0.29 | 0.33 | 0.157 |
| MaskGrasp | 0.50 | 0.46 | 0.45 | 0.116 |
| LLGD (ours) with 1 timestep | 0.47 | 0.34 | 0.40 | 0.035 |
| LLGD (ours) with 3 timesteps | 0.52 | 0.38 | 0.45 | 0.106 |
| LLGD (ours) with 10 timesteps | 0.53 | 0.39 | 0.46 | 0.264 |

Table 1: Comparison with Traditional Grasp Detection Methods.

We compare our LLGD with GR-ConvNet, Det-Seg-Refine, GG-CNN, CLIPORT, MaskGrasp, and CLIP-Fusion. Table 1 compares our method with these baselines on the Grasp-Anything dataset. The table shows that our proposed LLGD outperforms traditional grasp detection methods by a clear margin. Our inference time is also competitive with the other methods.

Comparison with Lightweight Diffusion Models

| Method | Seen | Unseen | H | Latency (s) |
|---|---|---|---|---|
| LGD with 3 timesteps | 0.42 | 0.29 | 0.35 | 0.074 |
| LGD with 30 timesteps | 0.49 | 0.41 | 0.45 | 0.741 |
| LGD with 1000 timesteps | 0.52 | 0.42 | 0.47 | 26.12 |
| SnapFusion with 500 timesteps | 0.49 | 0.37 | 0.43 | 12.95 |
| LightGrad with 250 timesteps | 0.51 | 0.34 | 0.43 | 6.420 |
| LLGD (ours) with 1 timestep | 0.47 | 0.34 | 0.40 | 0.035 |
| LLGD (ours) with 3 timesteps | 0.52 | 0.38 | 0.45 | 0.106 |
| LLGD (ours) with 10 timesteps | 0.53 | 0.39 | 0.46 | 0.264 |

Table 2: Comparison with Diffusion Models for Language-Driven Grasp Detection.

In this experiment, we compare our LLGD with other diffusion models for language-driven grasp detection. In particular, we compare with LGD using DDPM, and recent lightweight diffusion works: SnapFusion with 500 timesteps and LightGrad with 250 timesteps.

Table 2 shows the results of diffusion models for language-driven grasp detection. We can see that the accuracy and inference time of the classical diffusion model LGD strongly depend on the number of denoising timesteps. LGD with 1000 timesteps achieves reasonable accuracy but suffers from very long latency. Lightweight diffusion models such as SnapFusion and LightGrad show reasonable accuracy and inference speed. However, our method achieves the highest accuracy with the fastest inference speed.

Conditional Consistency Model Demonstration

Figure 1

Figure 1: Consistency model analysis. With text prompt input "Grasp the cup at its handle", we compare the trajectory grasp pose of our method and LGD. In the figure, the top row illustrates the trajectory of LGD, while the bottom row corresponds to the trajectory of our LLGD.

In this analysis, we verify the effectiveness of our conditional consistency model. In Figure 1, we visualize the grasp pose with respect to the time index $t$. In the LGD model, as a discrete diffusion model with $T=1000$ is employed, the diffusion steps must be performed with a step size of 1, which results in a very slow inference speed. Moreover, the grasp pose trajectory still exhibits significant fluctuations. Our method can arbitrarily select boundary time points for the continuous consistency model. It is evident that the number of iterations required by our method is significantly smaller than that of LGD for the same value of $T$, which contributes to the "lightweight" factor. Furthermore, our grasp pose at $t=603$ has almost converged to the ground truth, while LGD using DDPM at $t=350$ has not yet achieved a successful grasp.

Ablation Study

Figure 2

Figure 2: Visualization of detection results of different language-driven grasp detection methods.

Visualization. Figure 2 shows qualitative results of our method and other baselines. The outcomes suggest that, given the same text query, our LLGD generates more semantically plausible grasp poses than the other baselines. In particular, other methods often place grasp poses at locations not well aligned with the text query, while our method produces more suitable detections.

Figure 3

Figure 3: In the wild detection results. Images are from the internet.

In the Wild Detection. Figure 3 illustrates the outcomes of applying our method to random images from the internet. The results demonstrate that our LLGD can effectively detect the grasp pose given the language instructions on real-world images. Our method showcases a promising zero-shot learning ability, as it successfully interprets grasp actions on images it has never encountered during training.

Figure 4

Figure 4: Prediction failure cases.

Failure Cases. Although promising results have been achieved, our method still predicts incorrect grasp poses in some cases. Many objects and grasping prompts pose a challenging problem, as the network cannot capture all the diverse circumstances that arise in real life. Figure 4 depicts some failure cases where LLGD predicts incorrect results, which can be attributed to multiple similar objects that are difficult to distinguish and to text prompts that lack the detailed descriptions needed for accurate predictions.

Robotic Experiments

Robotic Setup. Our lightweight language-driven grasp detection pipeline is incorporated within a robotic grasping framework that employs a KUKA LBR iiwa R820 robot to deliver quantifiable outcomes. Utilization of the RealSense D435i camera enables the translation of grasping information from LLGD into a 6DoF grasp posture, bearing resemblance to previous works. Subsequently, a trajectory optimization planner is used to execute the grasping action. Experiments were conducted on a table surface for the single object scenario and the cluttered scene scenario, wherein various objects were placed to test each setup. Table 3 shows the success rate of our method and other baseline models.

| Baseline | Single | Cluttered |
|---|---|---|
| GR-ConvNet + CLIP | 0.33 | 0.30 |
| Det-Seg-Refine + CLIP | 0.30 | 0.23 |
| GG-CNN + CLIP | 0.10 | 0.07 |
| CLIPORT | 0.27 | 0.30 |
| CLIP-Fusion | 0.40 | 0.40 |
| SnapFusion | 0.40 | 0.39 |
| LightGrad | 0.41 | 0.40 |
| LLGD (ours) | 0.43 | 0.42 |

Table 3: Robotic language-driven grasp detection results.

Our method outperforms other baselines in both single object and cluttered scenarios. Furthermore, our lightweight model allows rapid execution speed without sacrificing the accuracy of visual grasp detection.

2. Discussion

Limitation. Despite achieving notable results in real-time applications, our method still has limitations and predicts incorrect grasp poses in challenging real-world images. Faulty grasp poses often occur when the correlation between the text and the attention map of the visual features is not well aligned, as shown in Figure 4. From our experiments, we observe that when grasp instruction sentences contain rare or challenging nouns that are underrepresented in the dataset, ambiguity arises in parsing the text prompts, which is usually the main cause of incorrect grasp pose predictions. Therefore, providing instruction prompts with clear meanings is essential for the robot to understand and execute the correct grasping action.

Future work. We see several prospects for improvement in future work:

  1. Expanding our method to handle 3D space is essential, implementing it for 3D point clouds and RGB-D images to avoid the lack of depth information in robotic applications.
  2. Addressing the gap between the semantic concept of text prompts and input images, analyzing the detailed geometry of objects to effectively distinguish between items with similar structures.
  3. Expanding the problem to more complex language-driven manipulation applications. For instance, if the robot wants to grasp a plate containing apples, it would need to manipulate the objects in such a manner that prevents the apples from falling.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 2)

Grasping Machine

1. Lightweight Language-driven Grasp Detection

Overview

Given an input RGB image and a text prompt describing the object of interest, we aim to detect the grasping pose on the image that best matches the text prompt input. We follow the popular rectangle grasp convention widely used in previous work to define the grasp.

In the diffusion model, we represent the target grasp pose as $\mathbf{x}_0$. The objective of our diffusion process for language-driven grasp detection involves denoising from a noisy state $\mathbf{x}_T$ to the original grasp pose $\mathbf{x}_0$, conditioned on the input image and grasp instruction represented by $y$. The forward process in the traditional conditional diffusion model is defined as:

$$q(\mathbf{x}_t|\mathbf{x}_{t-1}) = \mathcal{N}(\sqrt{1-\beta_t}\mathbf{x}_{t-1},\beta_t\mathbf{I})~, \tag{1}$$

where the hyperparameter $\beta_t$ is the amount of noise added at diffusion step $t \in [0,T] \subseteq \mathbb{R}$.

To train a diffusion model with condition y, we use a neural network to learn the reverse process:

$$p_\phi(\mathbf{x}_{t-1}|\mathbf{x}_t,y) = \mathcal{N}(\mu_\phi(\mathbf{x}_t,t,y),\Sigma_\phi(\mathbf{x}_t,t,y))~. \tag{2}$$

In our approach, we utilize the diffusion process in the continuous domain, where $\mathbf{x}_t$ is the grasp pose state at an arbitrary time index $t$. Unlike popular discrete diffusion models, by using a continuous space, we can improve sample quality and reduce inference time due to the ability to traverse the diffusion process at arbitrary timesteps, allowing for more fine-grained control over the denoising process.

Method Overview

Figure 1: The overview of our method. First, the input RGB image and text prompt are fed into the feature encoder and ALBEF fusion. Subsequently, we concurrently train two models with the same architecture: a score network to estimate the probability flow Ordinary Differential Equation (ODE) trajectory for the diffusion process, and a conditional consistency model to determine the grasp pose with a few denoising steps.

Conditional Consistency Model for LLGD

To reduce the inference time during the denoising step of the diffusion model, we aim to estimate the original grasp pose with just a few denoising steps. Since our language-driven grasp detection task has the condition $y$, we introduce a conditional consistency model based on the consistency concept to infer the original grasp pose directly during the inference process:

$$\mathbf{f}_\theta(\mathbf{x}_t,t,y) = \begin{cases} \mathbf{x}_t & t \in [0,\epsilon] \\ \mathbf{F}_\theta(\mathbf{x}_t,t,y) & t \in (\epsilon,T] \end{cases}~, \tag{3}$$

where $\mathbf{f}_\theta(\mathbf{x}_\epsilon, t, y) = \mathbf{x}_\epsilon$ is the boundary condition, and $\mathbf{F}_\theta(\mathbf{x}_t,t,y)$ is a free-form deep neural network whose output has the same dimensionality as $\mathbf{x}_t$.

To train our conditional consistency model, we employ knowledge distillation from a continuous diffusion process:

$$d\mathbf{x}_{t} = -\frac{1}{2}\gamma_t\mathbf{x}_t dt + \sqrt{\gamma_t} d\mathbf{w}_t~, \tag{4}$$

where $\gamma_t$ is a non-negative function referred to as the noise schedule, and $\mathbf{w}_t$ is the standard Brownian motion. This forward process creates a trajectory of grasp poses $\{\mathbf{x}_t\}_{t=0}^T$. The grasp pose state $\mathbf{x}_t$ depends on the time index $t$ and the input image and text prompt. The grasp distribution $p(\mathbf{x}_0|y)$ from the dataset is transformed into $p(\mathbf{x}_T|y) \sim \mathcal{N}(0, \mathbf{I})$. Given the ground truth grasp pose $\mathbf{x}_0$, we can sample $\mathbf{x}_t$ at arbitrary $t$:

$$p(\mathbf{x}_t|\mathbf{x}_0) = \mathcal{N}(\mu_t, \Sigma_t)~, \tag{5}$$

where

$$\mu_t = e^{\frac{1}{2}\rho_t} \mathbf{x}_0, \quad \Sigma_t = (1 - e^{\rho_t})\mathbf{I}, \quad \rho_t = -\int_{0}^{t} \gamma_s ds~.$$
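
Under these definitions, sampling the noisy grasp pose $\mathbf{x}_t$ from Equation (5) can be sketched as below; a constant noise schedule $\gamma_t = \gamma$ and the helper name `sample_xt` are assumptions made purely for illustration:

```python
import numpy as np

def sample_xt(x0: np.ndarray, t: float, gamma: float = 0.002, rng=np.random):
    """Sample x_t ~ N(mu_t, Sigma_t) for a constant noise schedule gamma_t = gamma.

    With a constant schedule, rho_t = -gamma * t, so
    mu_t = exp(rho_t / 2) * x0 and Sigma_t = (1 - exp(rho_t)) * I.
    """
    rho_t = -gamma * t
    mu_t = np.exp(0.5 * rho_t) * x0
    var_t = 1.0 - np.exp(rho_t)                      # tends to 1 as t grows: x_t -> N(0, I)
    return mu_t + np.sqrt(var_t) * rng.standard_normal(x0.shape)
```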

The SDE in Equation (4) has an associated probability flow ODE. With the conditional variable $y$, it can be written as:

$$\frac{d\mathbf{x}_t}{dt} = -\frac{1}{2}\gamma_t\left[\mathbf{x}_t + \nabla\log p(\mathbf{x}_t|y)\right]~, \tag{6}$$

where $\nabla\log p(\mathbf{x}_t|y)$ is the score function of the conditional diffusion model.

Suppose that we have a neural network $\mathbf{s}_\phi(\mathbf{x}_t, t, y)$ that can approximate the score function $\nabla\log p(\mathbf{x}_t|y)$, i.e., $\mathbf{s}_\phi(\mathbf{x}_t, t, y) \approx \nabla\log p(\mathbf{x}_t|y)$. After training the score network, we can replace the $\nabla\log p(\mathbf{x}_t|y)$ term in Equation (6) with the neural network:

$$\frac{d\mathbf{x}_t}{dt} = -\frac{1}{2}\gamma_t\left[\mathbf{x}_t + \mathbf{s}_\phi(\mathbf{x}_t, t, y)\right]~. \tag{7}$$

Score Function Loss. In order to approximate the score function $\nabla\log p(\mathbf{x}_t|y)$, the conditional denoising estimator minimizes the following objective:

$$\mathcal{L}_{\rm score}=\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~, \tag{8}$$

where $\lambda(t) \in \mathbb{R}^+$ is a positive weighting function.

Proposition 1. Suppose that $\mathbf{x}_t$ is conditionally independent of $y$ given $\mathbf{x}_0$; then minimizing $\mathcal{L}_{\rm score}$ is the same as minimizing:

$$\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_t,y \sim p(\mathbf{x}_t,y) \\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~.$$

Proof. Because $\mathbf{x}_t$ is conditionally independent of $y$ given $\mathbf{x}_0$, we have:

$$\begin{aligned} &\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0,y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \end{subarray} }\left[\Phi(t,y)\right]~, \tag{9} \end{aligned}$$

where

$$\begin{aligned} &\Phi(t,y)\\ &=\mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_0 \sim p(\mathbf{x}_0|y)\\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0,y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \end{aligned}$$

If $y$ and $t$ are fixed, we can define a transition probability that does not depend on these variables: $q(\mathbf{x}_0) = p(\mathbf{x}_0|y)$ and $\kappa(\mathbf{x}_t)=\mathbf{s}_\phi(\mathbf{x}_t,t,y)$. According to Vincent P., 2011, we have:

$$\begin{aligned} \Phi(t,y) &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_0 \sim q(\mathbf{x}_0)\\ \mathbf{x}_t \sim q(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t|\mathbf{x}_0) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} (\mathbf{x}_0,\mathbf{x}_t) \sim q(\mathbf{x}_0,\mathbf{x}_t)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t|\mathbf{x}_0) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_t \sim q(\mathbf{x}_t)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log q(\mathbf{x}_t) - \kappa(\mathbf{x}_t)\|^2 \right] \\ &= \mathbb{E}_{ \begin{subarray}{l} \mathbf{x}_t \sim p(\mathbf{x}_t|y)\\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \tag{10} \end{aligned}$$

From Equations (9) and (10), we can prove the equivalence of the two objective functions:

$$\begin{aligned} &\mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|\mathbf{x}_0) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|\mathbf{x}_0) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ =& \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ y \sim p(y) \\ \mathbf{x}_t \sim p(\mathbf{x}_t|y) \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right] \\ =& \mathbb{E}_{ \begin{subarray}{l} t \sim \mathcal{U}[0, T] \\ (\mathbf{x}_t,y) \sim p(\mathbf{x}_t,y) \\ \end{subarray} }\left[\lambda(t) \|\nabla\log p(\mathbf{x}_t|y) - \mathbf{s}_\phi(\mathbf{x}_t,t,y)\|^2 \right]~. \tag{11} \end{aligned}$$

Discretization. Consider discretizing the time horizon $[\epsilon,T]$ into $N-1$ sub-intervals with boundaries $t_1=\epsilon<t_2<t_3<\ldots<t_{N}=T$. If $N$ is sufficiently large, we can use an ODE solver to estimate the next discretization step:

$$\begin{aligned} \hat{\mathbf{x}}_{t_i} &= \mathbf{x}_{t_{i+1}} + (t_i - t_{i+1}) \left. \frac{d\mathbf{x}}{dt} \right|_{t = t_{i+1}} \\ &= \mathbf{x}_{t_{i+1}} - \frac{1}{2}\gamma_{t_{i+1}} (t_i - t_{i+1})\left[\mathbf{x}_{t_{i+1}} + \mathbf{s}_\phi(\mathbf{x}_{t_{i+1}},t_{i+1},y)\right]~. \tag{12} \end{aligned}$$

Conditional Consistency Model Loss. To enable fast sampling, we expect that the predicted point x^ti\hat{\mathbf{x}}_{t_i} and xti+1\mathbf{x}_{t_{i+1}} to lie on the same probability flow ODE trajectory. We propose conditional consistency loss to enforce this constraint:

$$\mathcal{L}_{\rm consistency} = \mathbb{E}_{ \begin{subarray}{l} i \sim \mathcal{U}[1, N - 1] \\ \mathbf{x}_{t_{i+1}} \sim p(\mathbf{x}_{t_{i+1}}|\mathbf{x}_0) \end{subarray} } \left[\lambda(t_i) \|\mathbf{f}_\theta(\mathbf{x}_{t_{i+1}},t_{i+1},y) - \mathbf{f}_{\theta^*}(\hat{\mathbf{x}}_{t_{i}},t_{i},y)\|^2 \right]~, \tag{13}$$

where $\hat{\mathbf{x}}_{t_i}$ is calculated as in Equation (12), $\mathbf{x}_{t_{i+1}}$ is sampled from the Gaussian distribution in Equation (5), and $\theta$ denotes the parameters of the neural network $\mathbf{f}$.

Additionally, we need to minimize the discrepancy between the predicted and ground truth grasp poses with the detection loss:

$$\mathcal{L}_{\rm detection} = \mathbb{E}_{ \begin{subarray}{l} i \sim \mathcal{U}[1, N] \\ \mathbf{x}_{t_{i}} \sim \mathcal{N}(\mu_{t_{i}},\Sigma_{t_{i}}) \\ \mathbf{x}_0,y \sim p(\mathbf{x}_0,y) \end{subarray} }\left[\lambda(t_i)\|\mathbf{f}_\theta(\mathbf{x}_{t_i}, t_i, y) - \mathbf{x}_0\|^2\right]~. \tag{14}$$

The overall training objective for our method is:

$$\mathcal{L}_{\rm total} = \mathcal{L}_{\rm consistency} + \mathcal{L}_{\rm detection}~. \tag{15}$$
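
A condensed PyTorch-style sketch of one training step under Equations (12)-(15) is given below. Here `score_net` and `consistency_net` stand in for the two networks in Figure 1, `ema_consistency_net` is the target copy $\mathbf{f}_{\theta^*}$, and the simplified noise-schedule handling (reusing $\mathbf{x}_{t_{i+1}}$ for the detection loss, constant-per-step $\gamma$ values) is an assumption rather than the authors' exact implementation:

```python
import torch
import torch.nn.functional as F

def training_step(score_net, consistency_net, ema_consistency_net,
                  x0, y, t_grid, gamma, lam=lambda t: 1.0):
    """One optimization step of L_total = L_consistency + L_detection (sketch).

    x0: ground-truth grasp poses (B, D); y: fused vision-language embedding;
    t_grid, gamma: 1-D tensors of discretized time points and noise-schedule values.
    """
    i = torch.randint(0, len(t_grid) - 1, (1,)).item()       # i ~ U[1, N-1]
    t_i, t_ip1 = t_grid[i], t_grid[i + 1]

    # Sample x_{t_{i+1}} ~ N(mu, Sigma) as in Eq. (5) (simplified schedule).
    rho = -gamma[i + 1] * t_ip1
    x_tip1 = torch.exp(0.5 * rho) * x0 + torch.sqrt(1 - torch.exp(rho)) * torch.randn_like(x0)

    # One Euler step of the probability-flow ODE, Eq. (12).
    with torch.no_grad():
        score = score_net(x_tip1, t_ip1, y)
        x_ti_hat = x_tip1 - 0.5 * gamma[i + 1] * (t_i - t_ip1) * (x_tip1 + score)

    # Consistency loss, Eq. (13): both points should map to the same origin.
    f_online = consistency_net(x_tip1, t_ip1, y)
    with torch.no_grad():
        f_target = ema_consistency_net(x_ti_hat, t_i, y)
    l_consistency = lam(t_i) * F.mse_loss(f_online, f_target)

    # Detection loss, Eq. (14): the predicted grasp pose should match the ground truth
    # (sampled here at t_{i+1} for brevity).
    l_detection = lam(t_i) * F.mse_loss(consistency_net(x_tip1, t_ip1, y), x0)

    return l_consistency + l_detection   # Eq. (15)
```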

Network Details

The input of our network is the image and a corresponding grasping text prompt represented as $e$ (for example, "grasp the fork at its handle"). We first extract the image feature using a 12-layer vision transformer (ViT) image encoder. The input text prompt is encoded by a text encoder using BERT or CLIP. We then combine and learn the features of the input text prompt and input image using the ALBEF fusion network. The output fusion features are fed into a score network, and our conditional consistency model is used to learn the grasp pose. Figure 1 shows the details of our network.

Score Network. In practice, we utilize a score network composed of several MLP layers that encode three components: the noisy grasp pose $\mathbf{x}_t$, the time index $t$, and the conditional vision-language embedding $y$. These features are then concatenated, and the score function is produced by a final MLP layer. It is crucial that the output dimension of the score network is identical to the dimension of the input $\mathbf{x}_t$ because, fundamentally, the score function is the gradient of the grasp pose distribution given the condition $y$. Our conditional consistency model has an architecture similar to the score network; however, its output is the predicted grasp pose.
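
The description above can be summarized with a small PyTorch module; the pose dimensionality, embedding size, and hidden width are illustrative assumptions rather than the authors' configuration:

```python
import torch
import torch.nn as nn

class ScoreNetwork(nn.Module):
    """MLP score network: encodes the noisy grasp pose x_t, the time index t,
    and the vision-language embedding y, then predicts a vector with the same
    dimensionality as x_t (the score). The consistency model shares this
    architecture but outputs the predicted grasp pose instead."""

    def __init__(self, pose_dim: int = 5, cond_dim: int = 768, hidden: int = 256):
        super().__init__()
        self.pose_mlp = nn.Sequential(nn.Linear(pose_dim, hidden), nn.ReLU())
        self.time_mlp = nn.Sequential(nn.Linear(1, hidden), nn.ReLU())
        self.cond_mlp = nn.Sequential(nn.Linear(cond_dim, hidden), nn.ReLU())
        self.head = nn.Sequential(nn.Linear(3 * hidden, hidden), nn.ReLU(),
                                  nn.Linear(hidden, pose_dim))

    def forward(self, x_t, t, y):
        # x_t: (B, pose_dim), t: (B, 1) time indices, y: (B, cond_dim) fused embedding.
        h = torch.cat([self.pose_mlp(x_t), self.time_mlp(t), self.cond_mlp(y)], dim=-1)
        return self.head(h)
```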

Algorithm 1: Inference Process

Input: Image and text prompt, conditional consistency model $\mathbf{f}_\theta(\mathbf{x},t,y)$, number of inference steps $P$, sequence of time points $t_1 = \epsilon < t_2 < t_3 < \dots < t_{P} = T$, noise scheduler $\alpha_t = e^{\rho_t}$.

$y \gets \text{ALBEF(image, prompt)}$

Initial grasp noise $\mathbf{x}_T \sim \mathcal{N}(0,\mathbf{I})$

$\mathbf{x}_0 \gets \mathbf{f}_\theta(\mathbf{x}_T,T,y)$

For $i = P - 1$ to $2$:

  • Sample $\mathbf{z} \sim \mathcal{N}(0,\mathbf{I})$
  • $\mathbf{x}_{t_i} \gets \sqrt{\alpha_{t_i}}\mathbf{x}_0 + \sqrt{1 - \alpha_{t_i}}\mathbf{z}$
  • $\mathbf{x}_0 \gets \mathbf{f}_\theta(\mathbf{x}_{t_i},t_i,y)$

Output: Final grasp pose $\mathbf{x}_0$
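
A NumPy-style sketch of Algorithm 1 follows; the consistency model `f_theta`, the precomputed embedding `y` (from ALBEF fusion of the image and prompt), and the `alpha` lookup are placeholders standing in for the trained components, with $\alpha_t = e^{\rho_t}$ as stated above:

```python
import numpy as np

def llgd_inference(f_theta, y, time_points, alpha, pose_dim=5, rng=np.random):
    """Few-step grasp pose sampling following Algorithm 1.

    f_theta: conditional consistency model f(x, t, y) -> predicted x_0.
    y: fused vision-language embedding from ALBEF(image, prompt).
    time_points: t_1 = eps < t_2 < ... < t_P = T.
    alpha: mapping from each time point t to alpha_t = exp(rho_t).
    """
    x_T = rng.standard_normal(pose_dim)            # initial grasp noise ~ N(0, I)
    x0 = f_theta(x_T, time_points[-1], y)          # first estimate at t = T
    for t in reversed(time_points[1:-1]):          # i = P-1, ..., 2
        z = rng.standard_normal(pose_dim)
        x_t = np.sqrt(alpha[t]) * x0 + np.sqrt(1.0 - alpha[t]) * z   # re-noise to level t
        x0 = f_theta(x_t, t, y)                    # refine the grasp pose estimate
    return x0
```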

Training and Inference

During training, we freeze the text and image encoders, then train the ALBEF fusion, the score network, and the consistency model end-to-end. The score network and the conditional consistency model share the same architecture. We train both models simultaneously for 1000 epochs with a batch size of 8 using the Adam optimizer. Training takes approximately three days on an NVIDIA A100 GPU. Regarding the parameters of the conditional consistency model, we empirically set $T = 1000$, $\epsilon = 1$, and $N = 2000$. After training the score network and the conditional consistency model $\mathbf{f}_\theta(\mathbf{x}_t,t,y)$, we can sample the grasp pose given the input image and language instruction prompt in a few denoising steps using Algorithm 1.

Figure 2: Robot hands with different utilities.

Next

In the next post, we will evaluate the effectiveness of our proposal.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 4)

We evaluate our proposed dataset, Guide3D, through a structured experimental analysis as follows: i) we first assess the dataset’s validity, focusing on reprojection errors and their distribution across the dataset to understand its accuracy; ii) we then explore the applicability of Guide3D in a 3D reconstruction task; and iii) finally, we benchmark several segmentation algorithms to assess their performance on Guide3D, providing insights into the dataset’s utility.

Dataset Validation

Our analysis revealed a non-uniform distribution of reprojection errors across the dataset, with the highest variability and errors concentrated at the proximal end of the guidewire reconstructions. Figure 1 shows the reprojection error patterns for both Camera A and Camera B. For Camera A, mean errors increase from approximately 6 px to a peak of 20 px, with standard deviations rising from 5 px to 11 px, indicating growing inaccuracies and variability over time. Significant fluctuations around indices 25 to 27 highlight periods of particularly high error. For Camera B, mean errors exhibit an initial peak of 9 px at index 1, followed by fluctuations that decrease towards the end. The standard deviations for Camera B start high at 11 px and decrease over time, reflecting a pattern of high initial variability that stabilizes later. These patterns are consistent with the inherent flexibility of the guidewire, which can form complex shapes such as loops.

Figure 1. Guidewire Reconstruction Error Analysis: (Left) Illustrates the distribution of reprojection errors, noting higher variability and peak errors in the mid-sections and reduced errors at the extremities. (Right) Presents the results of reconstruction validation.

Furthermore, we conducted a validation procedure using CathSim, incorporating the aortic arch model described in the next subsection and a guidewire of similar diameter and properties. For sampling, we employed the soft actor-critic (SAC) algorithm with segmented guidewires and kinematic data, producing realistic validation samples. Evaluation metrics included maximum Euclidean distance (MaxED) at 2.880 ± 0.640 mm, mean error in tip tracking (METE) at 1.527 ± 0.877 mm, and mean error related to the robot’s shape (MERS) at 0.001 ± 0.000. These results demonstrate the method’s precision.

Guidewire Prediction Results

We now demonstrate the capability of the introduced network and highlight the importance of the proposed dataset. We examine the network prediction in the following manner: 1) we first conduct an analysis between the predicted and reconstructed curve by employing piecewise metrics, and 2) we showcase the reprojection error.

Shape Prediction Errors: Table 1 presents the comparison of different metrics for shape prediction accuracy. We quantify the shape differences using the following metrics: 1) Maximum Euclidean Distance (MaxED), 2) Mean Error in Tip Tracking (METE), and 3) Mean Error in Robot Shape (MERS). For all the metrics, the shape of the guidewire, represented as a 3D curve C(u)\mathbf{C}(u), is sampled at equidistant Δu\Delta u intervals along the arclength parameter uu. Therefore, the metrics represent the pointwise discrepancies between the two shapes along the curve’s arclength.
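For illustration, the sketch below shows how such pointwise metrics could be computed in NumPy from two curves sampled at the same equidistant arclength parameters; it is an assumed implementation, not the benchmark's evaluation code.

```python
import numpy as np

def shape_metrics(pred, gt):
    """Pointwise errors between a predicted and ground-truth 3D curve,
    both sampled as [M, 3] arrays at the same arclength values."""
    d = np.linalg.norm(pred - gt, axis=1)   # per-point Euclidean distance
    max_ed = d.max()                        # MaxED: worst-case deviation
    mete = d[-1]                            # tip error for this frame (METE averages over frames)
    mers = d.mean()                         # MERS: mean deviation along the shape
    return max_ed, mete, mers
```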

The results indicate that the spherical representation consistently outperforms the Cartesian representation across all metrics. Specifically, the Maximum Euclidean Distance (MaxED) shows a lower error in the spherical representation (6.88 ± 5.23 mm) compared to the Cartesian representation (10.00 ± 4.64 mm). Similarly, the Mean Error in Tip Tracking (METE) is significantly lower in the spherical representation (3.28 ± 2.59 mm) than in the Cartesian representation (6.93 ± 3.94 mm). For the Mean Error in Robot Shape (MERS), the spherical representation also demonstrates a reduced error (4.54 ± 3.67 mm) compared to the Cartesian representation (5.33 ± 2.73 mm). Lastly, the Fréchet distance shows a smaller error for the spherical representation (6.70 ± 5.16 mm) compared to the Cartesian representation (8.95 ± 4.37 mm). These results highlight the advantage of using the spherical representation for more accurate shape prediction.

Table 1 Shape Comparison (mm).

Shape Comparison Visualization: Figure 2a showcases two 3D plots from different angles, comparing the ground truth guidewire shape to the shape predicted by the network. The network accurately predicts the guidewire shape, even in the presence of a loop and self-occlusion in the image, and the predicted shape aligns closely with the actual configuration of the guidewire. Notably, the proximal end exhibits a larger error than the distal end: discrepancies from the true guidewire shape range from about 2 mm at the distal end to roughly 5 mm at the proximal end. The network achieves this using only consecutive single-plane images. Subsequently, the 3D points are reprojected onto the original images, as illustrated in Figure 2b.

Figure 2. The figure illustrates the reconstruction similarity of the guidewire when reprojected onto the images. It demonstrates the network’s capability to accurately predict the guidewire shape, even in the presence of noticeable angles, highlighting the robustness of the prediction model.

Segmentation Results

We demonstrate Guide3D’s potential to advance guidewire segmentation research by evaluating the performance of three state-of-the-art network architectures: UNet (learning rate: 1×1051 \times 10^{-5}, 135 epochs), TransUnet (integrating ResNet50 and Vision Transformer (ViT-B-16), learning rate: 0.01, 199 epochs), and SwinUnet (Swin Transformer architecture, learning rate: 0.01, 299 epochs). Performance metrics included the Dice coefficient (DiceM), mean Intersection over Union (mIoU), and Jaccard index, detailed in Table 2. The results indicate that UNet achieved a DiceM of 92.25, mIoU of 36.60, and Jaccard index of 86.57. TransUnet outperformed with a DiceM of 95.06, mIoU of 41.20, and Jaccard index of 91.10. SwinUnet recorded a DiceM of 93.73, mIoU of 38.58, and Jaccard index of 88.55. These findings benchmark the dataset’s performance and suggest potential for future enhancements. Despite these promising results, the presence of loops and occlusions within the guidewire indicates that polyline prediction could significantly improve task utility.

Table 2 Segmentation Results.

Discussion and Conclusion

This paper introduces a new dataset, Guide3D, for segmentation and 3D reconstruction of flexible, curved endovascular tools. Extensive experiments demonstrate the dataset’s value; yet several limitations must be acknowledged. Firstly, our dataset lacks real clinical human data due to the complexity and regulatory challenges of acquiring such data. Our standardized platform, however, aims to enable further research, providing a stepping stone towards clinical practice.

Additionally, the dataset primarily focuses on synthetic and experimental scenarios, which may not fully capture the variability and unpredictability of real-world clinical environments. While this controlled setting aids initial algorithm development and benchmarking, further validation with clinical data is necessary to ensure the robustness and generalizability of the proposed methods.

Moreover, the guidewire’s flexibility and the presence of loops and occlusions present significant challenges for segmentation and reconstruction tasks. Our dataset includes these complexities to push the boundaries of current methodologies, but future work should explore more advanced techniques.

Our dataset accommodates both video and image-based approaches, providing a versatile resource to facilitate the translation of these technologies into clinical settings. Our objective is to bridge the disparity between research developments and clinical application by establishing a standardized framework for evaluating the efficacy of various methodologies. Our code and dataset will be made publicly available.

Lightweight Language-driven Grasp Detection using Conditional Consistency Model (Part 1)

Language-driven grasp detection is a fundamental yet challenging task in robotics with various industrial applications. This work presents a new approach for language-driven grasp detection that leverages lightweight diffusion models to achieve fast inference time. By integrating diffusion processes with grasping prompts in natural language, our method can effectively encode visual and textual information, enabling more accurate and versatile grasp positioning that aligns well with the text query. To overcome the long inference time problem in diffusion models, we leverage the image and text features as the condition in the consistency model to reduce the number of denoising timesteps during inference. The intensive experimental results show that our method outperforms other recent grasp detection methods and lightweight diffusion models by a clear margin. We further validate our method in real-world robotic experiments to demonstrate its fast inference time capability.

Grasping Machine

1. Introduction

Grasping is one of the fundamental tasks in robotics, enabling robots to interact with the physical world through a broad spectrum of applications, from industrial automation and human-robot interaction to service robotics. Recent advancements in machine vision have significantly improved the capabilities of grasp detection for robots. Prior research has demonstrated encouraging grasp detection results in both 2D images and 3D point clouds. However, most existing works define grasp detection as a region localization problem and ignore the use of natural language to localize possible grasps on the object.

Figure 1: Virtual demonstration of grasping a commanded object.

With the recent advances in Large Language Models (LLMs), integrating language into robotic systems has become more popular. Pretrained models such as ChatGPT and CLIP have revolutionized various applications, and their adaptability to the robotic domain has shown encouraging results. Although several works address language-driven robotic manipulation, most focus on understanding high-level actions and overlook the fundamental grasping task. In this paper, we tackle the language-driven grasp detection task, which allows the robot to grasp specific objects based on a language command. With language-driven grasping ability, robots can interact more effectively with the surrounding environment and humans.

Language-driven grasping offers several advantages compared to the traditional grasp detection task without text. Firstly, we communicate with robots by providing language prompts that direct them to execute precise tasks; therefore, the incorporation of natural language instructions augments robotic systems with the ability to respond to dynamic, real-time tasks interactively. Secondly, using natural language addresses the challenge of ambiguity in identifying target objects within cluttered environments or distinguishing among objects with similar shapes. Lastly, linguistic guidance enriches robotic systems with semantic information, enhancing their learning capabilities without necessitating expert demonstrations or specific engineering.

Several works on grasp detection have recently utilized diffusion models as the essential technique and shown encouraging results. This is motivated by the proven efficacy of diffusion models in conditional generation tasks such as image synthesis, image segmentation, and visual grounding. The effectiveness of diffusion models comes from their iterative approach to gradually refine data from an initial state of pure noise toward a meaningful output. Nonetheless, applying diffusion models to language-driven tasks in robotics faces a key challenge, i.e., the inference time of diffusion models is usually not fast enough for real-time robotic applications. Consequently, recent studies have introduced techniques to tackle the inference speed problem of diffusion models using approaches such as rapid sampling, knowledge distillation, or model optimization. However, these models can still not perform fast sampling with language conditions during inference to meet the real-time requirement in robotic grasping.

In this paper, we propose a new lightweight diffusion model to tackle the inference speed problem of diffusion models for the language-driven grasp detection task. To this end, we exploit the capabilities of flow-based generative models to improve the precision of robots in identifying grasp poses from textual inputs. In particular, we develop a conditional consistency model for fast inference speed in real-time robotic applications. We verify our proposed method on a recent large-scale language-driven grasping dataset and achieve superior accuracy and inference speed compared with recent approaches. Furthermore, our method enables zero-shot learning and generalizes to real-world robotic grasping applications.

Our contributions are summarized as follows:

  • We present Lightweight Language-driven Grasp Detection (LLGD), a fast diffusion model for language-driven grasp detection.
  • We conduct intensive analysis to validate our method and demonstrate that it outperforms other approaches in terms of both accuracy and execution speed.

2. Related Works

Grasp Detection. Grasp detection has been a central topic in robotics, aiming to equip robots with the ability to identify and execute object grasping in complex environments. Several works have set the foundation for robot grasping by using convolutional neural networks (CNNs). Most previous grasp detection methods are often limited to simple tasks with a fixed number of classes and rely solely on raw image data. Several works have extended the problem by using RGB-D images or 3D point clouds to output the results in 3D space. However, they still have not focused on integrating language as the input instruction in the grasp detection problem.

Language-driven Grasping. Language-driven grasp detection introduces the use of natural language to inform grasp detection tasks. The standard approach is to divide the task into a two-step process: one stage identifies the target object, and the second generates grasp poses based on the established visual-text correlations. Foundation models such as GroundDINO and CLIP have emerged, enabling zero-shot detection and segmentation. These models allow the target object to be localized without training. However, due to their large size, they result in longer inference times, and accessing such commercial foundation models is not always possible, especially since LLMs often require APIs that come at a high cost.

Lightweight Diffusion Model. Lightweight diffusion models that maintain performance while reducing computational overhead have become crucial in machine learning. Researchers have utilized knowledge distillation on low-resolution features to reduce the number of parameters in U-Net. Recently, consistency models have emerged as a robust family of generative models capable of producing high-quality images within a single or a limited number of steps. Despite their significant success in generative tasks, these models are primarily unconditional, whereas robotic applications remain discriminative, making unconditional diffusion models not entirely suitable. In this study, we address this issue by building a lightweight diffusion model with language conditions. We extend the consistency model to inherit its fast inference time while adding language conditions to make it more suitable for the language-driven grasping task.

Figure 2: The GraspNet dataset, widely used for grasp detection.

Next

In the next post, we will introduce our proposal Lightweight Language-driven Grasp Detection using Conditional Consistency Model.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 3)

Utilizing the Guide3D dataset, we build a benchmark for the shape prediction task, a critical component in endovascular intervention. Accurate shape prediction of the guidewire is essential for successful navigation and intervention. Here, we introduce a novel shape prediction network designed to predict the guidewire shape from a sequence of monoplanar images. This approach leverages deep learning to learn spatio-temporal correlations from a static camera observing a dynamic scene. Unlike conventional reconstruction methods that require biplanar images, our network uses a sequence of images to extract temporal information, allowing it to map a single image IA\mathbf{I}_A to the 3D guidewire curve C(u)\mathbf{C}(\mathbf{u}). By adopting this deep learning approach, we aim to simplify the shape prediction process while maintaining high accuracy. This method has the potential to enhance endovascular navigation by providing real-time, accurate predictions of the guidewire shape, ultimately improving procedural outcomes and reducing reliance on specialized equipment.

Network Key Components: The figure illustrates the essential components of the proposed model. a) Spherical coordinates (r,θ,ϕ)(r, \theta, \phi) are used for predicting the guidewire shape. b) The model predicts the 3D shape of a guidewire from image sequences It\mathbf{I}_t. A Vision Transformer (ViT) extracts spatial features zt\mathbf{z}_t, which a Gated Recurrent Unit (GRU) processes to capture temporal dependencies, producing hidden states ht\mathbf{h}_t. The final hidden state drives three prediction heads: the Tip Prediction Head for the 3D tip position pR3\mathbf{p} \in \mathbb{R}^3, the Spherical Offset Prediction Head for coordinate offsets (Δϕ,Δθ)(\Delta \phi, \Delta \theta), and the Stop Prediction Head for terminal point probability S\mathbf{S}.

Spherical Coordinates Representation

Predicting 3D points directly can be challenging due to the high degree of freedom. To mitigate this, we use spherical coordinates, which offer significant advantages over Cartesian coordinates for guidewire shape prediction. Spherical coordinates, as represented in Fig. 1a, are defined by the radius rr, polar angle θ\theta, and azimuthal angle ϕ\phi. They provide a more natural representation for the position and orientation of points along the guidewire, which is typically elongated and curved.

Mathematically, a point in spherical coordinates (r,θ,ϕ)(r, \theta, \phi) can be converted to Cartesian coordinates (x,y,z)(x, y, z) using the transformations:

x=rsinθcosϕ,y=rsinθsinϕ,z=rcosθ.x = r \sin \theta \cos \phi, \quad y = r \sin \theta \sin \phi, \quad z = r \cos \theta.

This conversion simplifies the modeling of angular displacements and rotations, as spherical coordinates directly encode directional information.

Predicting angular displacements (Δθ,Δϕ)(\Delta \theta, \Delta \phi) relative to a known radius rr aligns with the physical constraints of the guidewire, facilitating more accurate and interpretable shape predictions. By predicting an initial point (tip position) and representing subsequent points as offsets in Δϕ\Delta \phi and Δθ\Delta \theta while keeping rr fixed, this method simplifies shape comparison and reduces the parameter space. This approach enhances the model’s ability to capture the guidewire’s spatial configuration and improves overall prediction performance.
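As a small illustration of this representation, the NumPy sketch below unrolls a predicted tip and a sequence of angular offsets into 3D points, assuming a fixed step radius r and the coordinate convention above; the exact parameterization used by the model may differ.

```python
import numpy as np

def offsets_to_curve(tip, d_theta, d_phi, r):
    """Convert a tip position and accumulated spherical offsets into 3D points."""
    theta = np.cumsum(d_theta)                     # accumulated polar angles
    phi = np.cumsum(d_phi)                         # accumulated azimuthal angles
    steps = r * np.stack([np.sin(theta) * np.cos(phi),
                          np.sin(theta) * np.sin(phi),
                          np.cos(theta)], axis=1)  # fixed-radius displacement per point
    return np.vstack([tip, tip + np.cumsum(steps, axis=0)])
```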

Network Architecture

The proposed model (shown in Fig. 1b) addresses the problem of predicting the 3D shape of a guidewire from a sequence of images. Each image sequence captures the guidewire from different time steps IA,t\mathbf{I}_{A,t}, and the goal is to infer the continuous 3D shape Ct(ut)\mathbf{C}_t(\mathbf{u}_t). This many-to-one prediction task is akin to generating a variable-length sequence from variable-length input sequences, a technique commonly utilized in fields such as machine translation and video analysis.

To achieve this, the input pipeline consists of a sequence of images depicting the guidewire. A Vision Transformer (ViT), pre-trained on ImageNet, is employed to extract high-dimensional spatial feature representations from these images. The ViT generates feature maps ztR\mathbf{z}_t \in \mathbb{R}. These feature maps are then fed into a Gated Recurrent Unit (GRU) to capture the temporal dependencies across the image sequence. The GRU processes the feature maps zt\mathbf{z}_t from consecutive time steps, producing a sequence of hidden states ht\mathbf{h}_t. Formally, the GRU operation at time step tt is defined as:

ht=GRU(zt,ht1).\mathbf{h}_t = \text{GRU}(\mathbf{z}_t, \mathbf{h}_{t-1}).

The final hidden state ht\mathbf{h}_t from the GRU is used by three distinct prediction heads, each tailored for a specific aspect of the guidewire shape prediction: the Tip Prediction Head, responsible for predicting the 3D coordinates of the guidewire’s tip through a fully connected layer that maps the hidden state ht\mathbf{h}_t to a Cartesian anchoring point pR3\mathbf{p} \in \mathbb{R}^3; the Spherical Offset Prediction Head, which predicts the spherical coordinate offsets (Δϕ,Δθ)(\Delta \phi, \Delta \theta) for points along the guidewire with a fixed radius rr; and the Stop Prediction Head, which outputs the probability distribution indicating the terminal point of the guidewire by using a softmax layer to produce a probability tensor S\mathbf{S}, where each element Sj\mathbf{S}_j indicates the probability of the jj-th point being the terminal point.
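The sketch below outlines how the described ViT-GRU model with three prediction heads could be assembled in PyTorch; the feature dimensions, the placeholder backbone (standing in for a pretrained ViT), and the maximum number of points are assumptions for illustration.

```python
import torch
import torch.nn as nn

class GuidewireShapeNet(nn.Module):
    """Sketch: per-frame features -> GRU -> tip, spherical offsets, and stop heads."""
    def __init__(self, feat_dim=768, hidden=512, max_points=64):
        super().__init__()
        self.max_points = max_points
        self.backbone = nn.Linear(feat_dim, feat_dim)          # placeholder for ViT features
        self.gru = nn.GRU(feat_dim, hidden, batch_first=True)
        self.tip_head = nn.Linear(hidden, 3)                   # tip position p in R^3
        self.offset_head = nn.Linear(hidden, 2 * max_points)   # (d_phi, d_theta) per point
        self.stop_head = nn.Linear(hidden, max_points)         # terminal-point logits

    def forward(self, frame_feats):                            # [B, T, feat_dim]
        h, _ = self.gru(self.backbone(frame_feats))
        h_last = h[:, -1]                                      # final hidden state
        tip = self.tip_head(h_last)
        offsets = self.offset_head(h_last).view(-1, self.max_points, 2)
        stop = self.stop_head(h_last).softmax(dim=-1)          # probability tensor S
        return tip, offsets, stop
```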

Loss Function

The custom loss function for training the model combines multiple components to handle the point-wise tip error, variable guidewire length (stop criteria), and tip position predictions. The overall loss function Ltotal\mathcal{L}_{\text{total}} is defined as:

Ltotal=1Ni=1N(λtipp^ipi2+λoffset((ϕ^iϕi)2+(θ^iθi)2)+λstop(silog(s^i)(1si)log(1s^i)))\mathcal{L}_{\text{total}} = \frac{1}{N} \sum_{i=1}^N \bigg( \lambda_{\text{tip}} \left \| \hat{\mathbf{p}}_i - \mathbf{p}_i \right \|^2 + \lambda_{\text{offset}} \big( (\hat{\boldsymbol{\phi}}_i - \boldsymbol{\phi}_i)^2 + (\hat{\boldsymbol{\theta}}_i - \boldsymbol{\theta}_i)^2 \big) + \lambda_{\text{stop}} \big( -\mathbf{s}_i \log (\hat{\mathbf{s}}_i) - (1 - \mathbf{s}_i) \log (1 - \hat{\mathbf{s}}_i) \big) \bigg)

where NN is the number of samples, and λtip\lambda_{\text{tip}}, λoffset\lambda_{\text{offset}}, and λstop\lambda_{\text{stop}} are weights that balance the contributions of each loss component. The tip prediction loss (Ltip\mathcal{L}_{\text{tip}}) uses mean squared error (MSE) to ensure accurate 3D tip coordinates. The spherical offset loss (Loffset\mathcal{L}_{\text{offset}}) also uses MSE to align predicted and ground truth angular offsets, capturing the guidewire’s shape. The stop prediction loss (Lstop\mathcal{L}_{\text{stop}}) employs binary cross-entropy (BCE) to accurately predict the guidewire’s endpoint.
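A compact PyTorch sketch of this combined objective is shown below; the loss weights and tensor shapes are placeholders.

```python
import torch.nn.functional as F

def total_loss(pred, target, w_tip=1.0, w_offset=1.0, w_stop=1.0):
    """MSE on the tip, MSE on the spherical offsets, BCE on the stop probabilities."""
    tip_p, off_p, stop_p = pred          # predictions
    tip_g, off_g, stop_g = target        # ground truth
    l_tip = F.mse_loss(tip_p, tip_g)
    l_offset = F.mse_loss(off_p, off_g)
    l_stop = F.binary_cross_entropy(stop_p, stop_g)
    return w_tip * l_tip + w_offset * l_offset + w_stop * l_stop
```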

Training Details

The model was trained end-to-end using the loss from Equation above. The NAdam optimizer was used with an initial learning rate of 1×1041 \times 10^{-4}. Additionally, a learning rate scheduler was employed to adjust the learning rate dynamically based on the validation loss. Specifically, the ReduceLROnPlateau scheduler was configured to reduce the learning rate by a factor of 0.1 if the validation loss did not improve for 10 epochs. The model was trained for 400 epochs, with early stopping based on the validation loss to further prevent overfitting.
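For reference, the optimizer and scheduler described above could be configured roughly as follows; `model`, `train_one_epoch`, and `validate` are hypothetical placeholders for the actual training code.

```python
import torch

optimizer = torch.optim.NAdam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10)   # reduce LR on validation plateau

for epoch in range(400):
    train_one_epoch(model, optimizer)                 # placeholder training step
    val_loss = validate(model)                        # placeholder validation pass
    scheduler.step(val_loss)                          # early stopping would also monitor val_loss
```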

Next

In the next part, we will validate the effectiveness of Guidewire Shape Prediction dataset and methodology.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 2)

We propose the Guide3D dataset, a comprehensive resource specifically designed to advance 3D reconstruction and segmentation in endovascular navigation. This dataset addresses key limitations in the field, such as the scarcity of high-quality, publicly accessible datasets, by providing a diverse collection of real and synthetic imaging data. Guide3D includes detailed annotations for guidewire and catheter segmentation, alongside multi-view fluoroscopic data that supports accurate 3D modeling. By offering a standardized platform for algorithm development and evaluation, Guide3D aims to bridge the gap between research and clinical practice, facilitating improvements in precision, visualization, and tool tracking during endovascular procedures. Through this dataset, we seek to accelerate innovation in medical imaging, contributing to safer and more effective interventions.

Data Collection Setup

X-ray System. Our setup employed a bi-planar X-ray system equipped with 60 kW Epsilon X-ray generators and 16-inch image intensifier tubes by Thales, featuring dual focal spot Varian X-ray tubes for high-definition imaging. The system included Ralco automatic collimators for precise alignment and exposure, with calibration achieved through the use of acrylic mirrors and geometric alignment grids.

Anatomical Models. We utilized a half-body vascular phantom model from Elastrat Sarl Ltd., Switzerland, enclosed in a transparent box and integrated into a closed water circuit to simulate blood flow. Made from soft silicone and equipped with compact continuous flow pumps, it replicates human blood flow dynamics. The design is based on detailed postmortem vascular casts, ensuring anatomical accuracy reflective of human vasculature, facilitating realistic vascular simulations.

Figure 1. Dataset Overview: Guide3D contains 8,746 manually annotated frames from two views for 3D reconstruction (left), from which the reconstruction is derived (right).

Surgical Tools. To enhance our dataset, we navigated complex vascular structures using two types of guidewires commonly used in real-world endovascular surgery. The first, the Radifocus™ Guide Wire M Stiff Type (Terumo Ltd.), is made from nitinol with a polyurethane-tungsten coating. It measures 0.89 mm in diameter and 260 cm in length, with a 3 cm angled tip, designed for seeking, dissecting, and crossing lesions. The second, the Nitrex Guidewire (Nitrex Metal Inc.), also made of nitinol, features a gold-tungsten straight tip for enhanced radiopacity in fluoroscopic visualization. It has a diameter of 0.89 mm and a length of 400 cm, with a 15 cm tip, and is generally used for accessing or maintaining position during catheter exchanges. Both guidewires were selected to reflect real-world usage and to diversify the data in our dataset.

Figure 2. Materials: a) Overall setup & endovascular phantom, b) Radifocus (angled) guidewire. and c) Nitrex (straight) guidewire.

Data Acquisition, Labeling, and Statistics

Using the materials described in Subsection 3.1, we compiled a dataset of 8,746 high-resolution samples (1,024 × 1,024 pixels). This dataset includes 4,373 paired instances, both with and without a simulated blood flow medium. Specifically, it consists of 6,136 samples from the Radifocus guidewire and 2,610 from the Nitrex guidewire, providing a solid foundation for automated guidewire tracking in bi-planar scanner images. Manual annotation was carried out using the Computer Vision Annotation Tool (CVAT), where polylines were created to accurately track the dynamic path of the guidewires. The polyline representation was chosen because the guidewire's structure often results in overlapping sections, making a segmentation mask unsuitable. In contrast, a polyline effectively captures the looping nature of the guidewire, offering greater accuracy.

As shown in Table 1, the dataset includes 3,664 instances of angled guidewires with fluid and 484 without, while straight guidewires are represented by 2,472 instances with fluid and 2,126 without. This distribution reflects a variety of procedural contexts. All 8,746 images in the dataset are accompanied by manual segmentation ground truth, facilitating the development of algorithms that require segmentation maps as reference data.

Table 1. Dataset Composition Overview.

Calibration

We extract the camera parameters using a traditional undistortion and calibration method. Undistortion is first achieved with a local weighted mean (LWM) algorithm, using a perforated steel sheet with a hexagonal pattern as a framing reference, and applying a blob detection algorithm to precisely identify distortion points. This approach establishes correspondences between distorted and undistorted positions, allowing for accurate distortion correction.

Following this, a semi-automatic calibration step is performed for marker identification, and the random sampling consensus (RANSAC) method is used to ensure robustness in computing the projection matrix and deriving the intrinsic and extrinsic camera parameters. The calibration process is further refined through direct linear transformation (DLT) and non-linear optimization, utilizing multiple poses of the calibration object to optimize the overall camera setup. Figure 3 illustrates the calibration process.

Figure 3. Fluoroscopic Calibration: a) Undistortion grid application, and b) Point identification on calibration frame.

Guidewire Reconstruction

Given polyline representations of a curve in both planes, the reconstruction process begins by parameterizing these curves using B-Spline interpolation. Each curve is expressed as a function of the cumulative distance along its path. Let CA(uA)\mathbf{C}_A(\mathbf{u}_A) and CB(uB)\mathbf{C}_B(\mathbf{u}_B) represent the parameterized B-Spline curves in their respective planes, where uA\mathbf{u}_A and uB\mathbf{u}_B are the normalized arc-length parameters. The corresponding uB\mathbf{u}_B for a given uA\mathbf{u}_A is found using epipolar geometry. Once the corresponding points CA(uAi)\mathbf{C}_A(\mathbf{u}_A^i) and CB(uBi)\mathbf{C}_B(\mathbf{u}_B^i) are identified, their 3D coordinates Pi\mathbf{P}^i are computed by triangulation, resulting in a set of 3D points {Pi}i=1M\{\mathbf{P}^i\}_{i=1}^{M}, where MM is the total number of sampled points. This effectively reconstructs the original curve in 3D space.

To retrieve the fundamental matrix F\mathbf{F}, which describes the relationship between points in Image A (IA\mathbf{I}_A) and Image B (IB\mathbf{I}_B), the condition xBTFxA=0\mathbf{x}_B^T \mathbf{F} \mathbf{x}_A = 0 must hold for corresponding points xA\mathbf{x}_A in IA\mathbf{I}_A and xB\mathbf{x}_B in IB\mathbf{I}_B. Using the projection matrices PA\mathbf{P}_A and PB\mathbf{P}_B derived from the calibration process, the fundamental matrix can be calculated as follows:

F=[eB]×PBPA+\mathbf{F} = [\mathbf{e}_B]_\times \mathbf{P}_B \mathbf{P}_A^+

Here, eB\mathbf{e}_B is the epipole in Image B, defined as eB=PBCA\mathbf{e}_B = \mathbf{P}_B \mathbf{C}_A, with CA\mathbf{C}_A being the camera center of PA\mathbf{P}_A. The skew-symmetric matrix of the epipole eB\mathbf{e}_B is represented by:

[eB]×=[0eB3eB2eB30eB1eB2eB10][\mathbf{e}_B]_\times = \begin{bmatrix} 0 & -e_{B3} & e_{B2} \\ e_{B3} & 0 & -e_{B1} \\ -e_{B2} & e_{B1} & 0 \end{bmatrix}

Where eB=(eB1,eB2,eB3)T\mathbf{e}_B = (e_{B1}, e_{B2}, e_{B3})^T, and PA+\mathbf{P}_A^+ is the pseudoinverse of the projection matrix PA\mathbf{P}_A. The fundamental matrix F\mathbf{F} encapsulates the epipolar geometry between the two views, ensuring that corresponding points xA\mathbf{x}_A and xB\mathbf{x}_B lie on their respective epipolar lines.
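A short NumPy sketch of computing F from the two projection matrices, following the formula above, is given below.

```python
import numpy as np

def fundamental_from_projections(P_A, P_B):
    """F = [e_B]_x P_B P_A^+, with e_B = P_B C_A and C_A the camera centre of P_A."""
    _, _, Vt = np.linalg.svd(P_A)
    C_A = Vt[-1]                                   # right null vector (homogeneous camera centre)
    e_B = P_B @ C_A                                # epipole in image B
    e_cross = np.array([[0.0, -e_B[2], e_B[1]],
                        [e_B[2], 0.0, -e_B[0]],
                        [-e_B[1], e_B[0], 0.0]])   # skew-symmetric matrix [e_B]_x
    return e_cross @ P_B @ np.linalg.pinv(P_A)
```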

The matching phase begins by uniformly sampling points along the curve CA(uA)\mathbf{C}_A(u_A) at intervals ΔuA\Delta u_A. For each sampled point xA=CA(uA)x_A = \mathbf{C}_A(u_A), we project the epiline lB=FxAl_B = F x_A into Image B. We then determine the intersection of the epiline lBl_B with the curve CB(uB)\mathbf{C}_B(u_B), thereby obtaining the corresponding parameter uBu_B for each uAu_A.

Due to errors in the projection matrices PAP_A and PBP_B, there are instances where the epiline lBl_B does not intersect with any part of the curve CB\mathbf{C}_B. To address this, we fit a monotonic function fA(uA)uBf_A(u_A) \rightarrow u_B using a Piecewise Cubic Hermite Interpolating Polynomial (PCHIP), thus interpolating the missing intersections. The matching process is visualized in Fig. 4.
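A minimal SciPy sketch of this monotonic interpolation step might look as follows; the variable names are illustrative.

```python
import numpy as np
from scipy.interpolate import PchipInterpolator

def complete_matches(u_A, u_A_valid, u_B_valid):
    """Fit a monotonic map u_A -> u_B from the valid epiline intersections and
    interpolate the parameters where the epiline missed the curve."""
    f = PchipInterpolator(np.asarray(u_A_valid), np.asarray(u_B_valid))
    return f(np.asarray(u_A))
```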

Figure 4. Point Matching Process. Sampled points from image IAI_A (CA(uA)\mathbf{C}_A(u_A)) and their corresponding epilines lAl_A on image IBI_B are matched with their counterparts CB(uB)\mathbf{C}_B(u_B). The epilines for CB(uB)\mathbf{C}_B(u_B) are then computed and displayed on image IAI_A.

Utility of Guide3D Dataset for the Research Community

Guide3D advances endovascular imaging by providing a bi-planar fluoroscopic dataset for segmentation and 3D reconstruction, serving as an open-source benchmark. It enables precise algorithm comparisons for segmentation and facilitates method development in 3D reconstruction through the use of bi-planar imagery. With video data, Guide3D supports video-based methods, leveraging temporal dimensions for dynamic analysis. This enriches the segmentation and reconstruction capabilities, while also aligning with the procedural nature of endovascular interventions. This versatility highlights Guide3D's pivotal role in advancing endovascular imaging.

Next

In the next part, we will explore Guidewire Shape Prediction methodology.

Guide3D A Bi-planar X-ray Dataset for 3D Shape Reconstruction (Part 1)

Endovascular surgical tool reconstruction represents an important factor in advancing endovascular tool navigation, which is a crucial step in endovascular surgery. However, the lack of publicly available datasets significantly restricts the development and validation of novel machine learning approaches. Moreover, due to the need for specialized equipment such as biplanar scanners, most of the previous research employs monoplanar fluoroscopic technologies, hence only capturing the data from a single view and significantly limiting the reconstruction accuracy.

To bridge this gap, we introduce Guide3D, a bi-planar X-ray dataset for 3D reconstruction. The dataset represents a collection of high-resolution, bi-planar, manually annotated fluoroscopic videos captured in real-world settings. Validating our dataset within a simulated environment reflective of clinical settings confirms its applicability for real-world applications. Furthermore, we propose a new benchmark for guidewire shape prediction, serving as a strong baseline for future work. The proposal not only addresses an essential need by offering a platform for advancing segmentation and 3D reconstruction techniques but also aids the development of more accurate and efficient endovascular interventions.

Introduction

Minimally invasive surgery has revolutionized endovascular interventions, offering less invasive options with shorter recovery times. The success of these procedures depends on the precise navigation and manipulation of instruments such as guidewires and catheters. Typically, 2D visualization methods are used for guidance, with monoplanar fluoroscopy being the most common due to its minimal disruption to surgical workflows and relatively affordable cost. However, despite their widespread use, conventional imaging techniques have significant limitations, with one of the primary challenges being the lack of depth perception. This issue complicates the accurate visualization of surgical instruments, increasing the risk of excessive contact with arterial walls, which can compromise patient safety and the effectiveness of the procedure.

In endovascular interventions, depth perception is largely achieved through multi-view imaging systems, such as biplanar scanners, which allow shape reconstruction by combining images from multiple angles and employing epipolar geometry-based reconstruction. However, two major challenges hinder the broader adoption and effectiveness of these systems: (i) the difficulty of accurately segmenting images for successful shape reconstruction, exacerbated by the scarcity of datasets needed to evaluate segmentation methods, and (ii) the limited availability of specialized biplanar scanners in clinical settings due to their high cost. These challenges underscore the critical need for comprehensive datasets to enhance segmentation algorithm accuracy and improve guidewire reconstruction techniques, facilitating the development of more versatile imaging technologies.

Figure 1. Guide3D dataset contains 8,746 manually annotated frames from two views for 3D reconstruction.

In this paper, we introduce Guide3D, a dataset designed to advance 3D reconstruction in endovascular navigation. Guide3D provides a standardized platform for the development and evaluation of algorithms. With a comprehensive dataset that includes manual annotations for segmentation and tools for effective 3D visualization, Guide3D is intended to drive innovation and improvement in endovascular intervention. Furthermore, the inclusion of video-based biplanar fluoroscopic data allows for the exploration of temporal dynamics, such as using optical flow networks. Guide3D seeks to bridge the gap between research innovations and clinical applications, addressing key challenges in endovascular procedures.

Related Works

Endovascular Datasets.

Datasets play a crucial role in advancing endovascular navigation by providing essential resources for the development, evaluation, and enhancement of algorithms. These datasets, derived from various imaging modalities such as mono X-ray, 3D ultrasound, and 3D MRI, encompass both real and synthetic images, facilitating diverse applications in the medical field.

Mono X-ray datasets, while prevalent, often fall short in providing the necessary detail required for accurate 3D reconstruction, which is critical for effective surgical navigation. The inherent limitations of 2D imaging techniques make it challenging to fully capture the complexity of anatomical structures during procedures. In contrast, 3D imaging modalities like 3D ultrasound and 3D MRI offer more comprehensive views, enabling better depth perception and improved visualization of surgical tools and surrounding tissues.

Despite the importance of these datasets, there remains a significant gap in the availability of comprehensive, publicly accessible datasets specifically designed for tool segmentation and 3D reconstruction. This scarcity hampers progress in developing robust algorithms capable of accurately interpreting complex medical images. The lack of diverse and high-quality datasets also limits the ability of researchers to train and validate their algorithms effectively, often leading to suboptimal performance in clinical scenarios.

Furthermore, creating high-quality datasets is not merely a technical challenge; it requires collaboration among various stakeholders, including clinicians, radiologists, and data scientists. Such collaboration is essential to ensure that the datasets reflect real-world clinical conditions and include diverse patient populations. Expanding the availability of well-annotated datasets is vital for fostering innovation and advancing the field of endovascular surgery.

Figure 2. Endovascular Dataset Explanation.

Catheter and Guidewire Segmentation.

The segmentation of endovascular tools, particularly guidewires and catheters, is an evolving field that heavily relies on the availability and quality of datasets. Previous studies have often used synthetic and semi-synthetic data to address the challenges posed by the limited availability of real-world datasets. Researchers have employed manually annotated datasets from 2D X-ray and 3D MRI modalities to train segmentation models. Additionally, the effectiveness of synthetic datasets has been demonstrated in improving model efficiency.

The advent of deep learning techniques, especially U-Net architectures, has significantly enhanced the accuracy of segmentation and tracking for these surgical instruments. This advancement has led to the development of fully automated segmentation frameworks that utilize extensively annotated data and incorporate unsupervised techniques, such as optical flow. However, the absence of a public, standardized dataset for method comparison continues to impede the advancement and assessment of scientific progress in this area.

Figure 3. Interventional Microcatheters.

3D Reconstruction.

Improving the accuracy of 3D reconstruction in endovascular procedures plays a crucial role in achieving better clinical outcomes by enhancing catheter navigation through advanced visualization and precise tracking. Advances in fluoroscopic imaging technology have led to more accurate positioning of devices. Various algorithms have been developed to facilitate this process, employing techniques such as elastic grid registration and epipolar geometry for 3D reconstruction from biplane angiography. Additionally, automatic catheter detection methods utilizing triangulation and graph-search algorithms have been applied in electrophysiology studies to improve reconstruction outcomes.

Research has demonstrated the importance of accurate 3D models for navigation within both complex and single-view vascular architectures, highlighting the value of biplanar data. However, the limited availability of comprehensive, publicly accessible datasets for the development and validation of algorithms in 3D reconstruction poses a significant challenge to technological progress and clinical application. This situation underscores the critical need for specialized datasets to promote ongoing innovation in the reconstruction of endovascular tools.

Figure 4. Guidewire Calibration.

Next

In the next part, we will dive deeply into how to conduct a dataset for 3D shape reconstruction.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 4)

In the previous part, we introduced the training process and experimental setup. In this part, we validate the effectiveness and efficiency of the proposed method.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

Experimental Results

Quality Comparison

Table 1 presents a comparison among the baselines FACT, Transflower, EDGE, GDanceR, GCD, and our proposed manifold-based method. The results clearly demonstrate that our model outperforms the baselines across all evaluations on two datasets, AIOZ-GDANCE and AIST-M. We observe that recent diffusion-based dance generation models, such as EDGE or GCD, achieve competitive performance on both single-dance metrics (FID, MMC, GenDiv, and PFC) and group dance metrics (GMR, GMC, and TIF). However, due to limitations in their training procedures, they still struggle with generating multiple dancing motions when faced with a large number of dancers, as indicated by their lower performance compared to our method. This suggests that our approach successfully maintains the quality of dance motions as the number of dancers increases.

Table 1. Performance comparison.
Additionally, Figure 2 illustrates that our proposed method outperforms other state-of-the-art models like GDanceR and GCD in addressing issues such as monotonous, repetitive, sinking, and overlapping dance motions.
Figure 2. Visualization of a dancing sample between different methods. GDanceR displays monotonous, repetitive, or sinking dance motions. GCD exhibits more divergence in dance motions, yet dancers may intersect since their optimization does not address this issue explicitly. Blue boxes mark these issues. In contrast, our manifold-based solution ensures the divergence of dancing motions, while the phase motion path demonstrates its effectiveness in addressing floating and crossing issues in group dances.

Scalable Generation Analysis

Table 2 illustrates the performance of different group dance generation methods (GCD, GDanceR, and ours) when generating dance movements with an increasing number of dancers. When the number of dancers increases to 10, GCD runs out of memory, and the same is observed for GDanceR when the number of dancers increases to 100. Regardless of the number of dancers, our method consistently achieves stable and competitive results. This implies that our proposed method successfully addresses the scalability issue in group dance generation without compromising the quality of each individual dance motion.

Table 2. Performance of group dance generation methods as the number of generated dancers increases. Experiments are run on common consumer GPUs with 4GB of memory (N/A means the model could not run due to insufficient memory).

Figure 3 illustrates the memory consumption of each method when generating group dance motions. Noticeably, our proposal still achieves the highest performance while consuming far fewer resources for generating group dance motions (see Figure 4 for illustrations). This, again, indicates that our method successfully breaks the barrier of limited numbers of generated dancers by using the manifold.

Figure 3. Memory usage vs. number of dancers in different dance generators.

Figure 4. Visualization of a dancing sample between different methods. GDanceR displays monotonous, repetitive, or sinking dance motions. GCD exhibits more divergence in dance motions, yet dancers may intersect since their optimization does not address this issue explicitly. Blue boxes mark these issues. In contrast, our manifold-based solution ensures the divergence of dancing motions, while the phase motion path demonstrates its effectiveness in addressing floating and crossing issues in group dances.

Ablation Study

Table 3 presents the performance improvements achieved through the integration of the consistency loss and the phase manifold. Additionally, we showcase the effectiveness of our proposed approach across three different backbones: Transformer, LSTM, and CNN. Evaluation metrics include FID, GMR, and GMC. The results indicate that the absence of consistency loss leads to an increase in GMR and a decrease in GMC, suggesting that the proposed objective significantly enhances the realism and correlation of group dance motions. Meanwhile, without the phase manifold, the model exhibits markedly higher FID and GMR scores, suggesting that the phase manifold effectively improves the distinctiveness of dance motions while maintaining the realism of group dances, even when the number of dancers in a group is large. Comparing the three backbones, the chosen Transformer backbone achieves the best results, outperforming LSTM and CNN.

Table 3. Module contribution and loss analysis.

User Study

User studies are vital for evaluating generative models, as user perception is pivotal for downstream applications. We therefore conducted two studies, each with around 70 participants aged 20 to 40 (approximately 47\% female and 53\% male) from diverse backgrounds with experience in music and dance, to assess the effectiveness of our approach in group choreography generation.

In the user study, we assess the realism of sample outputs as the number of dancers grows. Specifically, participants assign scores from 0 to 10 to rate the realism of each dance clip containing 2 to 10 dancers. Figure 5 shows that, across all methods, realism decreases as the number of dancers increases. However, our proposed method exhibits the smallest drop in realism compared to GCD and GDanceR, indicating better performance than the other baselines as the number of dancers increases.

Figure 5. Realism between different methods when number of dancers is varied.

Discussion

While our approach leverages the VAE as a primary solution for generating a manifold, it is important to acknowledge certain inherent limitations associated with this choice. One notable challenge is the susceptibility to issues such as posterior collapse and unstable sampling within the VAE framework. These challenges can result in generated group dance motions that may not consistently meet performance expectations.

One specific manifestation of this limitation is the potential for false decoding when sampling points that lie too far from the learned distribution. This scenario can lead to unexpected rotations or disruptions in the physics of the generated content. The impact of this problem becomes evident in instances where the generated samples deviate significantly from the anticipated distribution, introducing inaccuracies and distortions.

To address these challenges, we recognize the need for ongoing efforts to mitigate the effects of posterior collapse and unstable sampling. While the problem is acknowledged, our approach incorporates measures to limit its impact. Future research directions could explore alternative generative models or additional techniques to enhance the robustness and reliability of the generated results in the face of these identified limitations.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 8)

Welcome to the concluding part of our series! In this final installment, we delve deeper into the search system that accompanies our proposal and conduct a comprehensive evaluation of its overall effectiveness and performance. We then wrap up by reflecting on the insights gained throughout this series.

1. Searching and Indexing Method in the CFOR System

To adapt our retrieval system to large-scale datasets, we have developed indexes for the CFOR system to facilitate nonexhaustive similarity search with GPU acceleration. We adopt the search algorithm outlined by Johnson et al. ("Billion-scale similarity search with GPUs") and implement it within the CFOR system for retrieval tasks. During search, the CFOR system improves accuracy by reducing the search space through additional information such as regions, categories, and attributes. For indexing, the object ontology aids in creating multi-indexing files to minimize search time. Our focus is on similarity search in vector collections using the L2 distance with a k-selection algorithm.

In the realm of searching, we distinguish between exact search (exhaustive search) and compressed search (greedy nonexhaustive search).

  • Exact search: almost all algorithms of this type compute the full pairwise distance between the query and every data point in the database, either sequentially or via an index file.
  • Compressed-domain search: almost all algorithms of this type compute the distance between the query and each data point after applying a space transformation, encoding, subspace splitting, or hashing. These methods reduce search time through index files, at the cost of some search accuracy. A minimal sketch of both search types follows this list.
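As a rough illustration (not the CFOR system's actual implementation), the sketch below contrasts an exact L2 index with a compressed IVF-PQ index using FAISS, the library accompanying the Johnson et al. work cited above; dimensions and parameters are placeholders.

```python
import numpy as np
import faiss

d = 512                                               # illustrative feature dimension
xb = np.random.rand(100_000, d).astype("float32")     # database feature vectors
xq = np.random.rand(5, d).astype("float32")           # query feature vectors

# Exact (exhaustive) L2 search
flat = faiss.IndexFlatL2(d)
flat.add(xb)
D_exact, I_exact = flat.search(xq, 10)

# Compressed-domain (non-exhaustive) search with an IVF-PQ index
quantizer = faiss.IndexFlatL2(d)
ivfpq = faiss.IndexIVFPQ(quantizer, d, 256, 64, 8)    # nlist, sub-quantizers, bits
ivfpq.train(xb)
ivfpq.add(xb)
ivfpq.nprobe = 16                                     # accuracy/speed trade-off
D_approx, I_approx = ivfpq.search(xq, 10)
```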

2. Data

Our fashion retrieval system was built on a subset of approximately 300,000 images from DeepFashion. In the DeepFashion dataset, objects are captured from different viewpoints against complicated backgrounds. Each image is richly annotated with fine-grained labels relevant to the current model's concern. The samples given in Figures 1 and 2 show more details about the DeepFashion dataset.

Figure 1. Images from the DeepFashion dataset captured from different views against complicated backgrounds.

Figure 2. Images from the DeepFashion dataset annotated with different labels based on details of input of the current model concern.

In testing, we employ part of the benchmark data to fine-tune the trained models and ensure that there are no fashion item overlaps between the fine-tuning and testing sets. The dataset is split into approximately 220,000 training images, 40,000 validation images, and 40,000 testing images. However, in attribute learning, we limited the number of attribute labels used for testing and the number of training images for specific attributes to create an imbalanced attribute dataset, so as to validate our proposed methods.

3. Results and Discussion

In the CFOR system, object ontology is useful in controlling training flow which impacts the performance of object category classification and attribute multitask classification. For object category classification, ontology controls the amount of training data through concepts. For attribute multitask classification, ontology manages local grouping which directly affects the performance of the proposed local imbalanced data solver on the large-scale dataset.

In this section, we will evaluate the effectiveness of different deep networks with the support of ontology on both category classification and attribute multitask classification in the CFOR system to pick out the best architecture for training the system. We will also compare our results with FashionNet.

Category Classification

We compare the performance of different deep architectures, including NASNet, ResNet-18, ResNet-101, FashionNet, NASNet with average pooling dropout (NASNet APD, proposed by us), and ResNet with average pooling dropout (ResNet APD, proposed by us). These experiments are evaluated by top-k accuracy (Figure 3). Our target is to find the best architecture to use as the core network of the CFOR system. This step serves as preparation before applying the CFOR system to fashion retrieval.

Figure 3. Accuracy plot for top-k accuracy in category classification.

Category classification accuracy with ResNet-18 APD is 1.23% higher (at k = 1) than with the original ResNet-18 architecture after removing nodes and applying average pooling. The corresponding gains are 0.93% for ResNet-101 (compared with the original ResNet-101 architecture) and 0.02% for NASNet v3 (compared with the original NASNet v3 architecture). The ResNet-101 APD architecture (the best architecture found) outperformed FashionNet (the best-performing architecture for category classification on the DeepFashion dataset compared to others such as WTBI or DARN) by 4.6% at k = 3 and 2.58% at k = 5. Based on these experimental results, the ResNet-101 architecture provides better classification and higher performance than the others (NASNet and ResNet-18). For this reason, we propose ResNet-101 as the core network architecture for training classification models.

Attribute Learning

Attribute multitask learning is an important part of the CFOR system. In this section, we evaluate the performance of the proposed local imbalanced data solver with MCC in dealing with the imbalanced attribute data on the large-scale fashion dataset.

Precision is the proportion of relevant instances among the retrieved instances, considering both true positives and false positives for each attribute. However, the counts of true positives and false positives are biased by the imbalanced data problem, so precision is also affected by it. Recall, on the other hand, depends on true positives but not false positives, and therefore better reflects performance on attributes with fewer samples; we use it to evaluate the experiments.

Local MTL outperforms STL and MTL on 28/35 attributes, with a 54.70% recall rate (higher than STL at 17.06% and MTL at 28.70%). While single-task learning shows its weakness on attributes with fewer samples, and multitask learning struggles with the severe imbalance and weaker inter-group correlations in fashion data, local MTL can reduce these negative influences while amplifying the positive effect of inner-group correlations on attribute learning. Thus, local MTL outperforms STL and MTL on 13/15 attributes with fewer samples (Figure 4).

Figure 4. Recall graph of 14 attributes in STL and local MTL.

Based on the experiment, aside from the chic, solid, and maxi attributes, which have equal accuracy with and without MCC, MTL with MCC achieved higher recall than MTL without MCC on 20/35 remaining attributes, and the overall performance increases by about 3%. For attributes with fewer data, MTL with MCC achieved higher recall than MTL without MCC on 9/14 attributes, and the overall performance on these attributes increases by 5.14% (see Figure 5 for more details).

Figure 5. Recall graph of 35 attributes using local multitask models with and without MCC.

Retrieval in CFOR System Results

In this experiment, we test the retrieval ability of the CFOR system using MAP computed from 1 retrieval result per query (MAP@1) up to 30 retrieval results per query (MAP@30). The similarity retrieval experiment checks whether the attributes extracted from retrieved images match the ground-truth attributes of the query image. Retrieval is based on deep features together with the 35 attributes. Across the 35 attributes belonging to 5 groups, the starting MAP@5 is acceptable (around 0.531), which shows the effectiveness of the search method. MAP@30 reaches around 0.815, and the upward trend shows the consistency and stability of the information prediction methods in the CFOR system. A simple visualization of the retrieval process in the CFOR system is shown in Figure 6.
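As a concrete illustration of this evaluation protocol, the sketch below computes MAP@K from ranked retrieval results, treating a retrieved image as relevant when its extracted attributes match the query's ground-truth attributes. The binary relevance encoding and toy data are assumptions for illustration; this is not the CFOR system's actual evaluation code.

```python
import numpy as np

def average_precision_at_k(relevances, k):
    """AP@k for one query; `relevances` is a binary list ordered by retrieval rank."""
    rel = np.asarray(relevances[:k], dtype=float)
    if rel.sum() == 0:
        return 0.0
    precision_at_i = np.cumsum(rel) / (np.arange(len(rel)) + 1)
    return float((precision_at_i * rel).sum() / rel.sum())

def map_at_k(all_relevances, k):
    """MAP@k averaged over all queries."""
    return float(np.mean([average_precision_at_k(r, k) for r in all_relevances]))

# Toy example with two queries: 1 = the retrieved image's attributes match the
# query's ground-truth attributes, 0 = they do not.
results = [[1, 0, 1, 1, 0],
           [0, 1, 1, 0, 0]]
for k in (1, 5):
    print(f"MAP@{k} = {map_at_k(results, k):.3f}")
```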

Figure 6. An example of the retrieval results of the CFOR system.

4. Conclusion

This work presents the coarse-to-fine object retrieval (CFOR) system, a learning framework for e-commerce online retrieval designed to deal with large-scale imbalanced datasets. The framework acts on both the input and the output and reconstructs datasets from the coarse-grained level to the fine-grained level, which we show to be an effective way to improve learning performance for retrieval. For input reconstruction, an ontology-based structure is used for threading the training flow, local grouping in multitask attribute learning, and hierarchical storage and retrieval. For output optimization, we take advantage of MCC to minimize the effect of the imbalanced dataset on multitask attribute learning.

Through extensive experiments, we demonstrate the applicability of object ontology in improving training flow, the effectiveness of different deep networks (ResNet and NASNet) applied on important tasks in fine-grained retrieval, and the usefulness of local multitask attribute learning and an MCC-based imbalanced data solver in attribute multitask learning. The CFOR system is designed to have flexibility so that it can be optimized easily in the future.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 3)

In the previous part, we introduced our main proposal, Variational Phase Manifold Learning. In this part, we explore the training process and the experimental setup used to validate the proposed method.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

Training

During training, we consider the following variational lower bound to mainly train our dance generation VAE model:

logpθ(xa)Eqϕ[logpθ(xz,a)]DKL(qϕ(zx,a)pθ(za))\log p_\theta(\mathbf{x}|\mathbf{a}) \geq \mathbb{E}_{q_\phi} \left[ \log p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a}) \right] - D_{\text{KL}}(q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a}) \Vert p_\theta(\mathbf{z}|\mathbf{a}))

In practice, we apply the conditional VAE loss, defined as the weighted sum Lcvae=Lrec+λKLLKL\mathcal{L}_{\text{cvae}} = \mathcal{L}_{\text{rec}} + \lambda_{\text{KL}}\mathcal{L}_{\text{KL}}. The reconstruction term Lrec\mathcal{L}_{\text{rec}} measures the motion reconstruction error of the decoder output (via a smooth-L1 loss), while the KL term LKL\mathcal{L}_{\text{KL}} measures the divergence DKLD_{\text{KL}} between the posterior and the prior distribution, encouraging them to stay close to each other.
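A minimal PyTorch sketch of this objective is given below, assuming diagonal Gaussian posterior and prior distributions and using the closed-form KL divergence; the helper names and tensor shapes are illustrative, and the default λ_KL value is taken from the implementation details reported later in this part.

```python
import torch
import torch.nn.functional as F

def kl_divergence(mu_q, logvar_q, mu_p, logvar_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians, summed over dims."""
    var_q, var_p = logvar_q.exp(), logvar_p.exp()
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(dim=-1).mean()

def cvae_loss(x_hat, x, mu_q, logvar_q, mu_p, logvar_p, lambda_kl=5e-4):
    """L_cvae = L_rec + lambda_KL * L_KL, with a smooth-L1 reconstruction term."""
    l_rec = F.smooth_l1_loss(x_hat, x)
    l_kl = kl_divergence(mu_q, logvar_q, mu_p, logvar_p)
    return l_rec + lambda_kl * l_kl

# Toy shapes: batch of 4 motion sequences; latent dimension 4D = 1024 (D = 256 phase channels).
x, x_hat = torch.randn(4, 120, 72), torch.randn(4, 120, 72)
mu_q, logvar_q = torch.randn(4, 1024), torch.randn(4, 1024)
mu_p, logvar_p = torch.randn(4, 1024), torch.randn(4, 1024)
print(cvae_loss(x_hat, x, mu_q, logvar_q, mu_p, logvar_p))
```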

The conditional VAE objective above is calculated for each dancer separately and cannot capture the correlation between dancers within a group. Therefore, it is essential to impose consistency among dancers and avoid strange effects such as unsynchronized dance. To this end, we propose a group consistency loss by enforcing the latent phase manifold to be similar for the same group, given the input music. Specifically, we first calculate the phase manifold features based on the frequency domain parameters as follows:

P2i1=Aisin(2πSi),P2i=Aicos(2πSi)\mathbf{P}_{2i-1} = \mathbf{A}_i\sin(2\pi \cdot \mathbf{S}_i), \qquad \mathbf{P}_{2i} = \mathbf{A}_i\cos(2\pi\cdot \mathbf{S}_i)

where PR2D\mathbf{P}\in\mathbb{R}^{2D} is the phase manifold vector that encodes the spatial-temporal information of the motion state. The phase feature may look similar to the positional encodings of transformers in the sense that both use multi-resolution sinusoidal functions. However, the phase feature is much richer in terms of representation capacity since it learns to embed the spatial (body joints) and temporal (positions in time) information of the motion curves, whereas the positional encodings only encode the position of words. Finally, our consistency objective is to constrain phase manifold between dancers within a group to be as close as possible to each other, which is formulated as:

Lcsc=DKL(qϕ(zxm,a)qϕ(zxn,a))+PmPn22\mathcal{L}_{\text{csc}} = D_{\text{KL}}(q_\phi(\mathbf{z}|\mathbf{x}^m,\mathbf{a}) \Vert q_\phi(\mathbf{z}|\mathbf{x}^n,\mathbf{a})) + \Vert \mathbf{P}^m - \mathbf{P}^n\Vert^2_2

where the first term encourages the network to map different dancers belonging to the same group (xm\mathbf{x}^m and xn\mathbf{x}^n) into the same set of distributional phase parameters while the second term penalizes the discrepancy in their corresponding phase manifolds. In general, this loss is applied to ensure every dancer is embedded into a single unified manifold that can effectively represent their corresponding group. To summarize, our total training loss is defined as the combination of the VAE loss and the consistency loss L=Lcvae+λcscLcsc\mathcal{L} = \mathbf{\mathcal{L}}_{\text{cvae}} + \lambda_{\text{csc}}\mathbf{\mathcal{L}}_{\text{csc}}.
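The following sketch illustrates, under assumed tensor shapes, how the phase manifold features and the group consistency loss above could be computed for a pair of dancers m and n. The diagonal-Gaussian KL helper and the toy batch are assumptions for illustration, not the paper's actual implementation.

```python
import torch

def phase_manifold(A: torch.Tensor, S: torch.Tensor) -> torch.Tensor:
    """P in R^{2D}: interleaves A_i*sin(2*pi*S_i) and A_i*cos(2*pi*S_i) per channel."""
    p_sin = A * torch.sin(2 * torch.pi * S)                  # P_{2i-1}
    p_cos = A * torch.cos(2 * torch.pi * S)                  # P_{2i}
    return torch.stack([p_sin, p_cos], dim=-1).flatten(1)    # (B, 2D)

def gaussian_kl(mu_m, logvar_m, mu_n, logvar_n):
    """KL between two diagonal Gaussians, summed over dimensions, averaged over the batch."""
    var_m, var_n = logvar_m.exp(), logvar_n.exp()
    kl = 0.5 * (logvar_n - logvar_m + (var_m + (mu_m - mu_n) ** 2) / var_n - 1.0)
    return kl.sum(dim=-1).mean()

def consistency_loss(mu_m, logvar_m, A_m, S_m, mu_n, logvar_n, A_n, S_n):
    """L_csc = KL(q(z|x^m,a) || q(z|x^n,a)) + ||P^m - P^n||_2^2 for dancers m and n."""
    kl_term = gaussian_kl(mu_m, logvar_m, mu_n, logvar_n)
    p_m, p_n = phase_manifold(A_m, S_m), phase_manifold(A_n, S_n)
    return kl_term + ((p_m - p_n) ** 2).sum(dim=-1).mean()

D = 256                                        # number of latent phase channels
mu_m, logvar_m = torch.randn(2, 4 * D), torch.randn(2, 4 * D)
mu_n, logvar_n = torch.randn(2, 4 * D), torch.randn(2, 4 * D)
A_m, S_m, A_n, S_n = (torch.rand(2, D) for _ in range(4))
print(consistency_loss(mu_m, logvar_m, A_m, S_m, mu_n, logvar_n, A_n, S_n))
```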

For testing, we can efficiently generate motions for an unlimited number of dancers by repeatedly drawing samples from the learned continuous group-consistent phase manifold. Notably, at inference we only need to run the prior network once to obtain the latent distribution. To generate a new motion, we sample from this latent (Gaussian) distribution and use the decoder to map it back to the motion space. This approach is much more efficient and scales significantly better than previous approaches, which are limited by the number of dancers the entire network must process simultaneously.
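A sketch of this inference procedure is shown below: the prior is run once on the music, and an arbitrary number of dancers is generated by repeatedly sampling from the resulting latent distribution and decoding. The `prior` and `decoder` callables and all shapes are placeholders standing in for the trained PDVAE modules.

```python
import torch

@torch.no_grad()
def generate_group(prior, decoder, music, num_dancers=100):
    """Run the prior once on the music, then sample and decode one latent per dancer."""
    mu, logvar = prior(music)                  # single prior pass for the whole group
    std = (0.5 * logvar).exp()
    dancers = []
    for _ in range(num_dancers):               # per-dancer cost stays constant
        z = mu + std * torch.randn_like(std)   # sample from the group-consistent manifold
        dancers.append(decoder(z, music))
    return torch.stack(dancers, dim=0)         # (num_dancers, T, pose_dim)

# Placeholder modules and shapes so the sketch runs end-to-end (all values illustrative).
latent_dim, T, pose_dim = 1024, 120, 291
prior = lambda music: (torch.zeros(latent_dim), torch.zeros(latent_dim))
decoder = lambda z, music: torch.zeros(T, pose_dim)

motions = generate_group(prior, decoder, music=torch.randn(T, 35))
print(motions.shape)   # torch.Size([100, 120, 291])
```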

Figure 2. An example output.

Experiments

Implementation Details

Our model was trained on 4 NVIDIA V100 GPUs using the Adam optimizer with a fixed learning rate of 10410^{-4} and a mini-batch size of 32 per GPU. For the training losses, the weights are empirically set to λKL=5×104\lambda_{\text{KL}} = 5\times 10^{-4} and λcsc=104\lambda_{\text{csc}} = 10^{-4}, respectively. The encoder, decoder, and prior network each consist of 5 Transformer layers with 512-dimensional hidden units, and the number of latent phase channels is set to 256. To further capture the periodic nature of the phase feature, we also use the Siren activation with its corresponding initialization scheme, which effectively models the periodicity inherent in the motion data and thus benefits motion synthesis.

Experimental Settings

Dataset. In our experiments, we utilize the AIOZ-GDance and AIST-M datasets. AIOZ-GDance is the largest music-driven dataset focusing on group dance, encompassing paired music and 3D group motions extracted from in-the-wild videos through a semi-automatic process. This dataset spans 7 dance styles and 16 music genres. For consistency, we adhere to the training and testing split during our experiments.

Evaluation Protocol. We employ several metrics to assess the quality of individual dance motions, including Frechet Inception Distance (FID), Motion-Music Consistency (MMC), and Generation Diversity (GenDiv), along with the Physical Foot Contact score (PFC). Specifically, the FID score gauges the realism of individual dance movements with respect to the ground-truth dance. MMC assesses the matching similarity between motion and music beats, reflecting how well generated dances synchronize with the music's rhythm. GenDiv is computed as the average pairwise distance of kinetic features among motions. PFC evaluates the physical plausibility of foot movements by measuring the agreement between the acceleration of the character's center of mass and the foot velocity.

In assessing the quality of group dance, we adopt three metrics: Group Motion Realism (GMR), Group Motion Correlation (GMC), and Trajectory Intersection Frequency (TIF). Broadly, GMR gauges the realism of generated group motions in comparison to ground-truth data, employing Frechet Inception Distance on extracted group motion features. GMC evaluates the synchronization among dancers within the generated group by computing their cross-correlation. TIF quantifies the frequency of collisions among the generated dancers during their dance movements.

Baselines. Our method is subjected to comparison with various recent techniques in music-driven dance generation, namely FACT, Transflower, and EDGE. These approaches are adapted for benchmarking within the context of group dance generation, considering that their original designs were tailored for single-dance scenarios. Additionally, our evaluation includes a comparison with GDanceR, GCD, and DanY. All of the mentioned works are specifically designed for the generation of group choreography.

Figure 3. Generated Group Dance from GCD baselines

Next

In the next part, we will explore the effectiveness of the proposal through quantitative and qualitative results.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 7)

To provide fine-grained information to the CFOR system, attribute learning is one of the most important tasks; it should be optimized for both processing time and the ability to deal with large-scale imbalanced datasets.

1. Framework

Local multitask learning is applied to attribute learning. The proposed framework, depicted in Figure 1 and Figure 2 and comprising online and offline phases, consists of three main components. The first component introduces a local multitask transfer learning model with a loss function designed to leverage inner-group correlations among attributes. The second component presents an imbalanced data resolver based on MCC (Matthews Correlation Coefficient) that requires no adjustments to the pretrained model or loss function. The third component discusses the prior knowledge used for local attribute grouping to facilitate local multitask learning.

Figure 1. Local MTL with an imbalanced data problem solver framework (Offline phase).

The inputs and outputs of the learning framework are images and their attribute vectors, respectively. With local grouping, however, the size of the attribute vector depends on the number of attributes in each group, and the dataset should be merged or split according to the local grouping scheme.

To evaluate the effectiveness of the proposed framework, we apply it to the fashion domain and split the dataset into five local groups: fabric, shape, part, style, and texture. Because fashion attributes have weaker intergroup correlations, the shared block should be designed to exploit inner-group correlations to improve overall performance. In contrast, for crowd attributes (such as activities, locations, and participants), intergroup correlations should be taken into account, so the shared block should be modified to fit that context.

Figure 2. Local MTL with an imbalanced data problem solver framework (Online phase).

2. Deep Multitask Learning

Our aim is to estimate a number of fashion attributes via a joint estimation model. However, when the attribute set is dynamic, MTL, which builds such a joint estimation model, becomes fragile during training because it does not scale well as the number of attributes increases. The local grouping method helps address this situation.

Framework in Detail

In the experiments, the proposed framework processes the query image and generates a confidence score vector comprising 7 attribute scores per group across 5 groups, which is then thresholded to obtain binary outputs. The architecture is outlined in detail below.

Figure 1 illustrates the overall structure of the proposed method. For each group, a training set is assumed, consisting of NN fashion images, each with MM attributes. The dataset is represented as D={(Xi,Yi)}D = \{(X_i, Y_i)\}, where XiX_i is the image and YiY_i is the label encoded as a one-hot vector. Inspired by prior research, we employ an end-to-end DNN architecture as a shared block to learn joint representations for all tasks. The loss function is binary cross-entropy, and the activation function at the output layer is sigmoid, chosen for its simplicity and flexibility in modifying the DNN architecture.
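The sketch below shows one possible realization of this setup for a single attribute group: a shared block followed by a linear head with sigmoid outputs (7 attributes per group, as in the experiments), trained with binary cross-entropy and thresholded at inference. The toy shared block and layer sizes are assumptions, not the actual backbone.

```python
import torch
import torch.nn as nn

class LocalAttributeGroupModel(nn.Module):
    """Shared block + sigmoid multi-label head for one attribute group (e.g., fabric)."""
    def __init__(self, shared_block: nn.Module, feat_dim: int, num_attributes: int = 7):
        super().__init__()
        self.shared_block = shared_block               # joint representation for the group's tasks
        self.head = nn.Linear(feat_dim, num_attributes)

    def forward(self, x):
        return self.head(self.shared_block(x))         # raw logits, one per attribute

# Toy shared block standing in for the deep backbone; dimensions are illustrative.
shared = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 128), nn.ReLU())
model = LocalAttributeGroupModel(shared, feat_dim=128, num_attributes=7)

criterion = nn.BCEWithLogitsLoss()                     # binary cross-entropy over sigmoid outputs
x = torch.randn(8, 3, 64, 64)
y = torch.randint(0, 2, (8, 7)).float()                # multi-hot attribute labels
loss = criterion(model(x), y)

# At inference, the confidence scores are thresholded to obtain binary attribute outputs.
predictions = (torch.sigmoid(model(x)) > 0.5).int()
```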

Network Architecture

NASNet automatically generates network architectures, constructing an optimal model by first searching for architectures on a smaller dataset and then scaling them up to a larger one. In practice, the search for the best cells is conducted on the CIFAR-10 dataset, and the resulting cells are then applied to the ImageNet dataset by stacking multiple copies of them, each with its own parameters. The resulting model demonstrates a 1.2% improvement in top-1 accuracy compared to the best human-designed architectures. NASNet thus proves its effectiveness over previous architectures and offers a transfer learning model trained on the diverse ImageNet dataset. Leveraging the NASNet model pretrained on ImageNet, we apply transfer learning to the DeepFashion dataset to speed up convergence and enhance performance. Additionally, a dropout layer is incorporated with NASNet to mitigate overfitting. While using the NASNet model generation algorithm to tailor a model specifically for the DeepFashion dataset is a promising approach, the time and hardware resources required to generate and train NASNet from scratch are significant; due to hardware limitations, only transfer learning is employed.

Figure 3. Best normal cells and reduction cells identified with CIFAR-10; the ImageNet architecture (right) is built from the best convolutional cells. Zoph et al. designed two types of cells so that architectures could be created for images of any size: normal cells return a feature map with the same dimensions, while reduction cells return a feature map with height and width reduced by a factor of two.

We experiment with NASNet architectures to determine which one is suitable for each specific task in our CFOR system. In our fashion retrieval experiments, the category classifier and region classifier tasks use transfer learning with single-task learning, while fashion attribute recognition uses local multitask learning. In addition, to adapt to large-scale datasets and reduce overfitting, we recommend replacing the final fully connected layer with a global average pooling layer followed by dropout.

Next

In the next post, we will discover Searching and Indexing Method in the CFOR System as well as the effectiveness of the whole system.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 2)

In the previous part, we introduced group dance scalability and explained what a manifold is. In this part, we introduce our main proposal, Variational Phase Manifold Learning.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

Task Definition

Given an input music sequence a={a1,a2,...,aT}\mathbf{a} = \{a_1, a_2, ...,a_T\}, where t{1,...,T}t \in \{1,..., T\} indexes the music frames, our goal is to generate the group motion sequences of NN arbitrary dancers: x={x11,...,xT1;...;x1N,...,xTN}\mathbf{x} = \{x^1_1,..., x^1_T; ...;x^N_1,...,x^N_T\}, where xtnx^n_t is the pose of the nn-th dancer at frame tt. We use the 6D continuous rotation representation for every joint, along with 3D joint positions and velocities. Additionally, the corresponding 3D root translation vectors are concatenated into the pose representation to capture the trajectory of the motion. Previous group dance methods, which generate the whole group at once, cannot handle an increasing number of dancers and can only create group sequences up to a pre-defined number of dancers, due to the vast complexity of their architectures. In contrast, we aim to generate group dances with an unlimited number of dancers.
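As a small illustration of this pose representation, the sketch below assembles a per-frame pose vector from 6D rotations, 3D positions, 3D velocities, and the 3D root translation. The joint count (24, an SMPL-style skeleton) is an assumption, since the exact skeleton is not specified here.

```python
import numpy as np

NUM_JOINTS = 24   # assumption: SMPL-style skeleton; the exact joint count is not fixed here

def pose_dim(num_joints: int = NUM_JOINTS) -> int:
    # 6D rotation + 3D position + 3D velocity per joint, plus a 3D root translation.
    return num_joints * (6 + 3 + 3) + 3

def build_pose(rot6d, pos, vel, root_trans):
    """Concatenate per-joint features and the root translation into one pose vector x_t^n."""
    return np.concatenate([rot6d.ravel(), pos.ravel(), vel.ravel(), root_trans])

x_t = build_pose(np.zeros((NUM_JOINTS, 6)), np.zeros((NUM_JOINTS, 3)),
                 np.zeros((NUM_JOINTS, 3)), np.zeros(3))
assert x_t.shape[0] == pose_dim() == 291
```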

Figure 2. An example output.

Phase-conditioned Dance VAE

Our goal is to learn a continuous manifold such that the motion can be generated by sampling from this learned manifold. We assume that although different dancers within the same group may present visually distinctive movements, the properties of their motions, such as timing, periodicity, or temporal alignment, are intrinsically similar. We aim to learn a generative phase representation for each group of dancers in order to synthesize their motion indefinitely. Our generative model is built upon the conditional Variational Autoencoder architecture, thanks to its diverse generation capability and fast sampling speed. However, instead of directly encoding the data into a Gaussian latent distribution as in common VAE approaches, we model the latent variational distribution by the phase parameters extracted from the latent motion curve, which we call the variational phase manifold. The latent phase manifold is well-structured and describes key characteristics of motion well (such as its timing, local periodicity, and transition), which benefits learning motion features.

The overview of our Phase-conditioned Dance VAE is illustrated in Figure 3. Specifically, the model contains three main networks: an encoder E\mathcal{E} to capture the approximate posterior distribution conditioned on both motion and music qϕ(zx,a)q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a}), a prior network P\mathcal{P} to learn the conditional prior given only the music pθ(za)p_\theta(\mathbf{z}|\mathbf{a}) , and a decoder D\mathcal{D} to learn to reconstruct the data from the latent distribution pθ(xz,a)p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a}). The new motion is generated by sampling the frequency-domain parameters predicted by the prior network, which is then passed through the decoder network to reconstruct the motion in the original data space. Furthermore, we adopt Transformer-based architecture in each network to effectively capture long-range dependencies and holistic context of the dance sequence.

Figure 3. Overview of our Phase-conditioned Dance VAE (PDVAE) for scalable group dance generation. It consists of an Encoder, a Prior, and a Decoder network. During training, we encode the corresponding motion and music inputs into a latent phase manifold, which is variational and parameterized by the frequency domain parameters of periodic functions. The latent phases can be sampled from the manifold and then decoded back to the original data space to obtain new motions. The consistency loss Lcsc\mathcal{L}_{\text{csc}} is further imposed to constrain the manifold to be consistently unified for dancers that belong to the same group. At inference stage, only the Prior and the Decoder are used to synthesize group dances efficiently.

Encoder

The encoder E\mathcal{E} is expected to take both the motion and music feature sequence as input, and produce a distribution over possible latent variables capturing the cross-modal relationship between them. To transform the joint input space into a learned phase manifold, we adopt the Transformer decoder architecture where the Cross-Attention mechanism is utilized to learn the relationship between the motion and the music. Accordingly, the output of the encoder is a batch of latent curves (i.e., the activation sequences per channel) that can particularly capture different spatial and temporal aspects of the motion sequence. However, instead of training the model to directly reconstruct the input motion from the extracted latent curves, we further enforce each channel of the latent space to have a periodic functional form (i.e., sinusoidal). This enables us to effectively learn a compact parameterization for each latent channel from a small set of parameters in the frequency domain.

Generative Variational Phase Manifold

Here we focus on learning the periodicity and non-linear temporal alignment of the motion in the latent space. In particular, given the output latent curves from the encoder L=E(x,a)RD×T\mathbf{L} = \mathcal{E}(x,a) \in \mathbb{R}^{D \times T}, where DD is the number of desired phase channels to be extracted from the motion, we parameterize each latent curve in L\mathbf{L} using a sinusoidal function with amplitude (A\mathbf{A}), frequency (F\mathbf{F}), offset (B\mathbf{B}) and phase shift (S\mathbf{S}) parameters. To allow for variational phase manifold learning, we opt to predict two sets of parameters μE={μA;μF;μB;μS}\mathbf{\mu}_{\mathcal{E}} =\{\mathbf{\mu}^A; \mathbf{\mu}^F; \mathbf{\mu}^B; \mathbf{\mu}^S \} and σE={σA;σF;σB;σS}\mathbf{\sigma}_{\mathcal{E}} =\{\mathbf{\sigma}^A; \mathbf{\sigma}^F; \mathbf{\sigma}^B; \mathbf{\sigma}^S \}, which correspond to the mean and variance of an R4D\mathbb{R}^{4D}-dimensional Gaussian distribution:

qϕ(zx,a)=N(z;μE,σE)q_\phi(\mathbf{z}|\mathbf{x},\mathbf{a}) = \mathcal{N}(\mathbf{z};\mathbf{\mu}_{\mathcal{E}}, \mathbf{\sigma}_{\mathcal{E}})

To do so, we first apply a differentiable Fast Fourier Transform (FFT) to each channel of the latent curve L\mathbf{L} and create the zero-indexed matrix of Fourier coefficients as c=FFT(L)\mathbf{c}=FFT(\mathbf{L}) with cCD×(K+1)\mathbf{c} \in \mathbb{C}^{D \times (K+1)}, K=T2K =\lfloor \frac{T}{2}\rfloor. We then compute the per-channel power spectrum pRD×(K+1)\mathbf{p} \in \mathbb{R}^{D \times (K+1)} as pi,j=2Nci,j2\mathbf{p}_{i,j} = \frac{2}{N}|\mathbf{c}_{i,j}|^2, where ii is the channel index and jj is the index of the frequency bands. The distributional mean parameters of the periodic sinusoidal function are then calculated as follows:

μiA=2Tj=1Kpi,j,μiF=j=1Kfjpi,jj=1Kpi,j,μiB=ci,0T,\mathbf{\mu}^A_i = \sqrt{\frac{2}{T}\sum_{j=1}^K \mathbf{p}_{i,j}}, \quad \mathbf{\mu}^F_i = \frac{\sum_{j=1}^K \mathbf{f}_j \cdot \mathbf{p}_{i,j}}{ \sum_{j=1}^K \mathbf{p}_{i,j}}, \quad \mathbf{\mu}^B_i = \frac{\mathbf{c}_{i,0}}{T},

where f=(0,1T,,KT)\mathbf{f} = (0, \frac{1}{T},\dots,\frac{K}{T}) is the frequency vector. At the same time, the phase shift S\mathbf{S} is predicted using a fully-connected (FC) layer followed by a two-argument arctan\arctan (i.e., atan2) activation:

(sy,sx)=FC(Li),μiS=arctan(sy,sx),(s_y, s_x) = \text{FC}(\mathbf{L}_i), \quad \mathbf{\mu}^S_i = \arctan(s_y,s_x),

To predict the distributional variance of the phase amplitude and phase shift parameters {σA,σS}\{\mathbf{\sigma}^A, \mathbf{\sigma}^S\}, we additionally apply a separate two-layer MLP network over each channel of the latent curves. The variational latent phase parameters are sampled using the reparameterization trick, i.e., AN(μA,σA)\mathbf{A}\sim\mathcal{N}(\mathbf{\mu}^A,\mathbf{\sigma}^A) and SN(μS,σS)\mathbf{S}\sim\mathcal{N}(\mathbf{\mu}^S,\mathbf{\sigma}^S). In our experiments, we find that sampling the phase frequency F\mathbf{F} and offset B\mathbf{B} often produces unstable and non-coherent group movements. This might be because the frequency amplitudes of the dancers within the same group are likely to associate with the rhythmic pattern of the musical beats, while the offsets capture their alignment, and therefore should be consistent with each other. We thus treat those parameters as deterministic by constraining their variance to zero.

Finally, the sampled set of phase parameters z={A;F;B;S}\mathbf{z} = \{\mathbf{A};\mathbf{F};\mathbf{B};\mathbf{S}\} are used to reconstruct a parametric latent space consisting of multiple periodic curves to represent each intrinsic property of the motion by:

L^=Asin(2π(FTS))+B\hat{\mathbf{L}} = \mathbf{A} \cdot \sin (2\pi \cdot (\mathbf{F}\cdot\mathcal{T} - \mathbf{S})) + \mathbf{B}

where T\mathcal{T} is a known time window series obtained by evenly spacing the timesteps from 00 to TT. Intuitively this curve construction procedure can be viewed as a "quantization" layer to enforce the network to learn to represent the motion features in the frequency domain, which is useful in representing different aspects of human motion such as their timing and periodicity. In the last step, a decoder is utilized to reconstruct the original motion signals from the set of parametric latent curves.
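The sketch below illustrates the deterministic part of this parameterization for a single sample: an FFT over each latent curve, the power spectrum, the amplitude/frequency/offset means, and the sinusoidal reconstruction of the latent curves. It assumes the normalization constant N in the power spectrum equals the sequence length T, omits the batch dimension, and uses a zero phase shift in place of the FC + atan2 branch.

```python
import torch

def phase_parameters(L: torch.Tensor):
    """L: (D, T) latent curves -> per-channel amplitude, frequency, and offset means."""
    D, T = L.shape
    K = T // 2
    c = torch.fft.rfft(L, dim=-1)                       # (D, K+1) complex Fourier coefficients
    power = (2.0 / T) * c.abs() ** 2                    # per-channel power spectrum
    p = power[:, 1:K + 1]                               # bands j = 1..K (DC band excluded)
    freqs = torch.arange(1, K + 1, dtype=L.dtype) / T   # f_j = j / T
    mu_A = torch.sqrt((2.0 / T) * p.sum(dim=-1))        # amplitude mean
    mu_F = (freqs * p).sum(dim=-1) / p.sum(dim=-1)      # frequency mean (power-weighted)
    mu_B = c[:, 0].real / T                             # offset mean from the DC coefficient
    return mu_A, mu_F, mu_B

def reconstruct_curves(A, F, B, S, T: int):
    """L_hat = A * sin(2*pi*(F*t - S)) + B on an evenly spaced time window of length T."""
    t = torch.linspace(0.0, float(T), T).unsqueeze(0)   # (1, T)
    return (A.unsqueeze(-1) * torch.sin(2 * torch.pi * (F.unsqueeze(-1) * t - S.unsqueeze(-1)))
            + B.unsqueeze(-1))

L = torch.randn(256, 120)                               # 256 phase channels, 120 frames
mu_A, mu_F, mu_B = phase_parameters(L)
S = torch.zeros(256)                                    # placeholder for the FC + atan2 phase shift
L_hat = reconstruct_curves(mu_A, mu_F, mu_B, S, T=120)
print(L_hat.shape)                                      # torch.Size([256, 120])
```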

Figure 4. Manifold conduction.

Decoder

To decode the latent space into the original motion space, previous works have to use a sinusoidal positional encoding sequence with duration TT as the proxy input to the sequence decoder. This is because their latent space is formed only by single latent vectors following a Gaussian distribution, which cannot span the time dimension. However, we observe that this usually results in unstable and inconsistent movements, as the proxy sequence is generic and usually contains little meaningful information for the decoder. Our method does not suffer from this problem, as our latent space is built on multiple curves that can represent the motion information through time, thanks to the phase parameters. Subsequently, our decoder D\mathcal{D} is based on the Transformer decoder architecture and takes the constructed parametric latent curves, as well as the music features, as inputs to reconstruct the corresponding dance motions. Here, we also utilize the cross-attention model, where the sequence of music features serves as the key and value and the sampled latent curves serve as the query. The output of the decoder is a sequence of TT vectors in RD\mathbb{R}^D, which is then projected back to the original motion dimensions through a linear layer to obtain the reconstructed outputs x^=pθ(xz,a)\hat{\mathbf{x}}=p_\theta(\mathbf{x}|\mathbf{z},\mathbf{a}). We additionally employ a global trajectory predictor to predict the global translation of the root joint based on the generated local motions, in order to avoid intersection problems between dancers.

Prior Network.

Since the ground-truth motion is generally inaccessible at test time (i.e., we only have access to the music), we also need to learn a prior P\mathcal{P} that matches the posterior distribution of motion, from which the latent phases can be sampled. Specifically, we follow the manifold procedure to predict the Gaussian distribution conditioned on the music sequence a\mathbf{a}, which is then used for sampling the latent phases:

pθ(za)=N(z;μP,σP)p_\theta(\mathbf{z}|\mathbf{a}) = \mathcal{N}(\mathbf{z};\mathbf{\mu}_{\mathcal{P}}, \mathbf{\sigma}_{\mathcal{P}})

where a Transformer encoder is used to encode the input conditioning music sequence and predict the corresponding μP\mathbf{\mu}_{\mathcal{P}} and σP\mathbf{\sigma}_{\mathcal{P}}. We implement the prior network similarly to the encoder network; however, we use a self-attention mechanism to capture the global music context. Learning the conditional prior is crucial for the conditional VAE to generalize to diverse types of music and motion. Intuitively, each latent variable z\mathbf{z} is expected to represent possible dance motions x\mathbf{x} conforming to the music context a\mathbf{a}. Therefore, the prior should be able to encode different latent distributions given different music inputs.

Next

In the next part, we will explore the training procedure and experimental setups to validate the effectiveness of the proposed method.

Large-Scale Coarse-to-Fine Object Retrieval Ontology and Deep Local Multitask Learning (Part 6)

In this section, we discuss ontology, fashion ontology, and related information, and present the contributions of object ontology to the CFOR system.

1. Ontology Definition for CFOR System

As described by Guarino, ontology is a "formal, explicit specification of a shared conceptualization." Typically, ontologies consist of concepts and their hierarchical structure, aiding in organizing information within a domain. A complete ontology typically includes concepts, relations, and axioms. Additionally, ontologies offer several key advantages:

  • Describing domain knowledge through a semantic hierarchical tree, with concepts represented as nodes identified by words or phrases.
  • Bridging the semantic gap in various tasks, including those in computer vision and other disciplines.
  • Enhancing software engineering practices by improving flexibility, reliability, specification, and reusability.
  • Supporting multitask problem-solving capabilities.

Any proposed ontology should satisfy two fundamental criteria:

  • Wide recognition within the community.
  • Feasibility for formalization using mathematical expressions, enabling digitization.

In our approach, we employ ontological engineering to facilitate communication and information sharing across different levels of data abstraction involved in image fashion retrieval, detection, and information tagging.

The object ontology comprises two primary levels: coarse-grained and fine-grained.

  • At the coarse level, the object ontology includes regions, categories, or high-level conceptual entities, which leverage global features extracted by deep networks for similarity retrieval. However, these deep features are treated as black boxes, lacking explicit semantic information to aid users in their search process.
  • At the fine-grained level, the object ontology encompasses attributes that provide detailed descriptions of objects.

In our experiment, we focus on describing the object "Fashion." The fashion ontology is constructed using prior knowledge and information from the DeepFashion dataset, along with ontology definitions introduced by Guarino. See Figure 1 for an illustration of the fashion ontology.

Figure 1. Fashion ontology in general and a version of ontology for clothes.

The fashion ontology developed comprises three primary semantic levels: 1. Regions: Representing areas such as Top, Bottom, and Body. 2. Categories: Specific objects associated with each region, such as dresses or robes for the Body region. 3. Attributes: Describing detailed visual concepts like denim or fur.

To streamline the discussion, our investigation focuses on the object fashion across three regions (Top, Body, and Bottom), select categories within these regions, and their respective attributes.

Within the CFOR system, a query image undergoes processing starting from the coarse level of the object ontology to identify the region and category of the corresponding object. Subsequently, the object proceeds to the fine-grained concept ontology to ascertain attributes. Once all necessary information is obtained, the object undergoes indexing and similarity distance computation to identify similar images in the database, ranked by a cumulative score derived from similarity scores of global features and attribute learning between the query image and target database images. For a detailed illustration, refer to Figure 2.
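The sketch below illustrates this ranking step under simple assumptions: cosine similarity for the global deep features, the fraction of matching binary attributes for the fine-grained level, and an equal weighting between the two. The actual cumulative score used by the CFOR system may combine these signals differently.

```python
import numpy as np

def cosine(u, v):
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))

def cumulative_score(query_feat, query_attrs, db_feat, db_attrs, w_feat=0.5, w_attr=0.5):
    """Combine global-feature similarity and attribute agreement (weights are illustrative)."""
    feat_sim = cosine(query_feat, db_feat)
    attr_sim = float((query_attrs == db_attrs).mean())   # fraction of matching binary attributes
    return w_feat * feat_sim + w_attr * attr_sim

# Toy database: each item holds a global deep feature and a 35-dimensional binary attribute vector.
rng = np.random.default_rng(0)
query_feat, query_attrs = rng.normal(size=128), rng.integers(0, 2, 35)
database = [(rng.normal(size=128), rng.integers(0, 2, 35)) for _ in range(3)]

ranking = sorted(range(len(database)),
                 key=lambda i: cumulative_score(query_feat, query_attrs, *database[i]),
                 reverse=True)
print("Ranked database indices:", ranking)
```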

Figure 2. An example of a relationship between the query image and semantic information from the coarse-grained level to the fine-grained level of the fashion ontology.

2. Fashion Object Ontology

In this section, we introduce the fashion object ontology. Within the fashion domain, we categorize semantic fashion concepts based on regions. Each region encompasses a detailed ontology comprising categories and attributes. To facilitate experiments using the DeepFashion dataset, we extend the fashion ontology within the "Clothes" branch (refer to Figure 9). It's essential to emphasize that the proposed ontology is not specific to any application and should be viewed as a flexible foundation.

The fashion object ontology consists of multiple levels of concepts, with relations between each level to articulate their associations. Two primary relations are employed: 1. "Part of": This relation specifies that the concepts are components of the main concept. 2. "Has a": This relation describes the main concept in detail.

For this study, we concentrate solely on the Clothes branch to ensure equitable comparisons with other methodologies. The Clothes taxonomy comprises 50 distinct categories. A clothing region taxonomy has been established (refer to Figure 3), organizing all clothing categories hierarchically. The first level of this hierarchy represents the most general clothing region, with three primary regions defined: 1. Top (e.g., tee and tank) 2. Bottom (e.g., skirt and jeans) 3. Body (e.g., dress and robe)
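As an illustration only, the clothes branch described above could be encoded with the two relations as a small nested structure; the concept names are taken from the examples in the text, and this is not the system's actual storage format.

```python
# Each concept keeps its "part of" children (sub-concepts) and "has a" links (attribute groups).
fashion_ontology = {
    "Clothes": {"part_of": ["Top", "Bottom", "Body"], "has_a": []},
    "Top":     {"part_of": ["Tee", "Tank"],           "has_a": ["Fabric", "Shape", "Part", "Style", "Texture"]},
    "Bottom":  {"part_of": ["Skirt", "Jeans"],        "has_a": ["Fabric", "Shape", "Part", "Style", "Texture"]},
    "Body":    {"part_of": ["Dress", "Robe"],         "has_a": ["Fabric", "Shape", "Part", "Style", "Texture"]},
}

def categories_of(region: str) -> list:
    """Return the categories that are 'part of' a clothing region."""
    return fashion_ontology.get(region, {}).get("part_of", [])

print(categories_of("Body"))   # ['Dress', 'Robe']
```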

Figure 3. Excerpt from the “Clothes” taxonomy defined in the fashion ontology.

3. Fine-Grained Object Ontology

Fine-grained object ontology is used to describe objects at the attribute level. Semantic information such as attributes can be useful for a customer to retrieve (see Figure 4). It is important to note that the proposed ontology is not application dependent and should be considered as an extensible basis.

Figure 4. Fine-grained group at the attribute level.

Cloth attributes vary across different levels—some attributes, like color, are common across all cloth regions, while others are specific to certain regions or categories. Our ontology is structured into two main parts, each detailed in the following sections: 1. Specific fashion concepts—pertaining to particular characteristics of clothes such as fabric, part, and style. 2. Visual concepts—related to popular visual characteristics like color, shape, and texture, not exclusive to fashion.

Rudd et al. demonstrated in a study that a multitask learning-based model outperforms a combination of single-task learning-based models in face attribute prediction. While this approach shows promising results for fashion attributes as well, there's a significant difference in the quantity of attributes between faces and fashion items. This disparity can pose challenges in scaling the system, such as in training and storage requirements. To address this, we propose applying local multitask learning to attribute learning, providing more flexibility. Further explanation is provided in the subsequent sections.

Next

In the next post, we will discover Attribute Learning and its correlation with multitask learning.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 3)

In the previous article, we delved into integrating the contrastive divergence loss function into federated learning, exploring its potential benefits for enhancing model performance and tackling non-IID data issues in autonomous driving contexts. In this follow-up piece, we delve into various federated learning configurations and present experimental findings that support the efficacy of our approach.

1. Implementation

Dataset: Our experimentation involves three datasets (see Table 1): Udacity+, Gazebo Indoor, and Carla Outdoor. Gazebo and Carla datasets exhibit non-IID characteristics, while Udacity+ represents a non-IID variant of the Udacity dataset.

| Dataset | Total samples | Average samples in each silo (Gaia) | Average samples in each silo (NWS) | Average samples in each silo (Exodus) |
|---|---|---|---|---|
| Udacity+ | 38,586 | 3,508 | 1,754 | 488 |
| Gazebo | 66,806 | 6,073 | 3,037 | 846 |
| Carla | 73,235 | 6,658 | 3,329 | 927 |

Table 1: Statistics of the datasets used in our experiments.

Network Topology: Our experimentation encompasses three distinct federated topologies, namely the Internet Topology Zoo (Gaia), North American data centers (NWS), and the Zoo Exodus network (Exodus). The primary focus is on the Gaia topology, while supplementary insights from the NWS and Exodus topologies are provided in our ablation study.

Training: Within each silo, model training is executed with a batch size of 32 and a learning rate set at 0.001, facilitated by the Adam optimizer. The local training regimen within each silo precedes the transmission and aggregation of models using the specified global aggregation equation. The training regimen spans 3,600 communication rounds and leverages a simulation environment akin to that described in Nguyen et al. (2022), powered by an NVIDIA 1080 GPU.

Baselines: Our comparative analysis involves several contemporary methods across diverse learning scenarios, including Random and Constant baselines as outlined by Loquercio et al. (2018). Within the Centralized Local Learning (CLL) scenario, we utilize Inception-V3, MobileNet-V2, VGG-16, and Dronet as baseline models. In the context of Server-based Federated Learning (SFL), our comparison extends to FedAvg, FedProx, and STAR. For the Decentralized Federated Learning (DFL) setting, our evaluation includes MATCHA, MBST, and FADNet. Model effectiveness is assessed using Root Mean Square Error (RMSE) and Mean Absolute Error (MAE), while wall-clock time (ms) serves as a metric for training duration.

2. Quantitative Results

We benchmark our CDL against recent state-of-the-art approaches across the CLL, SFL, and DFL learning scenarios; Table 2 reports the results on the Gaia topology.

| Model | Main Focus | Learning Method | RMSE (Udacity+) | RMSE (Gazebo) | RMSE (Carla) | MAE (Udacity+) | MAE (Gazebo) | MAE (Carla) | # Training Parameters | Avg. Cycle Time (ms) |
|---|---|---|---|---|---|---|---|---|---|---|
| Random | _ | _ | 0.358 | 0.117 | 0.464 | 0.265 | 0.087 | 0.361 | _ | _ |
| Constant | Statistical | _ | 0.311 | 0.092 | 0.348 | 0.209 | 0.067 | 0.232 | _ | _ |
| Inception | Architecture Design | CLL | 0.209 | 0.085 | 0.297 | 0.197 | 0.062 | 0.207 | 21,787,617 | _ |
| MobileNet | Architecture Design | CLL | 0.193 | 0.083 | 0.286 | 0.176 | 0.057 | 0.200 | 2,225,153 | _ |
| VGG-16 | Architecture Design | CLL | 0.190 | 0.083 | 0.316 | 0.161 | 0.050 | 0.184 | 7,501,587 | _ |
| DroNet | Architecture Design | CLL | 0.183 | 0.082 | 0.333 | 0.150 | 0.053 | 0.218 | 314,657 | _ |
| FedAvg | Aggregation Optimization | SFL | 0.212 | 0.094 | 0.269 | 0.185 | 0.064 | 0.222 | 314,657 | 152.4 |
| FedProx | Aggregation Optimization | SFL | 0.152 | 0.077 | 0.226 | 0.118 | 0.063 | 0.151 | 314,657 | 111.5 |
| STAR | Aggregation Optimization | SFL | 0.179 | 0.062 | 0.208 | 0.149 | 0.053 | 0.155 | 314,657 | 299.9 |
| MATCHA | Topology Design | DFL | 0.182 | 0.069 | 0.208 | 0.148 | 0.058 | 0.215 | 314,657 | 171.3 |
| MBST | Topology Design | DFL | 0.183 | 0.072 | 0.214 | 0.149 | 0.058 | 0.206 | 314,657 | 82.1 |
| FADNet | Topology Design | DFL | 0.162 | 0.069 | 0.203 | 0.134 | 0.055 | 0.197 | 317,729 | 62.6 |
| CDL (ours) | Loss Optimization | CLL | 0.169 | 0.074 | 0.266 | 0.149 | 0.053 | 0.172 | 629,314 | _ |
| CDL (ours) | Loss Optimization | SFL | 0.150 | 0.060 | 0.208 | 0.104 | 0.052 | 0.150 | 629,314 | 102.2 |
| CDL (ours) | Loss Optimization | DFL | 0.141 | 0.062 | 0.183 | 0.083 | 0.052 | 0.147 | 629,314 | 72.7 |

Table 2: Performance comparison between different methods. The Gaia topology is used.

The results indicate that our CDL under the Siamese setup with two ResNet-8 models outperforms the other methods by a significant margin, achieving substantial reductions in both RMSE and MAE across all three datasets (Udacity+, Gazebo, and Carla). Although the deployed network's parameter count is unchanged, since the sub-network is removed at inference, CDL does require a larger model during training because of the additional sub-network in the Siamese setup. Moreover, CDL with ResNet-8 outperforms the other baselines most clearly in the DFL scenario, and to a lesser extent in the SFL and CLL setups.

3. Contrastive Divergence Loss Analysis

CDL Performance Across Various Topologies In practice, training federated algorithms becomes more complex as the topology involves more vehicle data silos. To assess the efficacy of our CDL, we conduct training experiments and compare the results with other baseline methods across different topologies. Table 3 presents the performance comparison of DroNet, FADNet, and our CDL with a ResNet-8 backbone when trained using DFL across three distributed network infrastructures with varying numbers of silos: Gaia (11 silos), NWS (22 silos), and Exodus (79 silos). The table clearly demonstrates that our CDL consistently achieves superior results across all topology configurations. In contrast, DroNet encounters divergence issues, and FADNet exhibits suboptimal performance, particularly in the Exodus topology with 79 silos.

| Topology | Architecture | Udacity+ | Gazebo | Carla |
|---|---|---|---|---|
| Gaia (11 silos) | DroNet | 0.177 (↓0.036) | 0.073 (↓0.011) | 0.244 (↓0.061) |
|  | FADNet | 0.162 (↓0.021) | 0.069 (↓0.007) | 0.203 (↓0.020) |
|  | CDL (ours) | 0.141 | 0.062 | 0.183 |
| NWS (22 silos) | DroNet | 0.183 (↓0.045) | 0.075 (↓0.017) | 0.239 (↓0.057) |
|  | FADNet | 0.165 (↓0.027) | 0.070 (↓0.012) | 0.200 (↓0.018) |
|  | CDL (ours) | 0.138 | 0.058 | 0.182 |
| Exodus (79 silos) | DroNet | 0.448 (↓0.310) | 0.208 (↓0.147) | 0.556 (↓0.380) |
|  | FADNet | 0.179 (↓0.041) | 0.081 (↓0.020) | 0.238 (↓0.062) |
|  | CDL (ours) | 0.138 | 0.061 | 0.176 |

Table 3: Performance under different topologies.

CDL with Various Architectures Our CDL, functioning as a loss function, exhibits versatility across different network architectures when integrated into the Siamese setup, leading to performance enhancements. Figure 1 showcases the efficacy of CDL across diverse networks such as DroNet, FADNet, Inception, MobileNet, and VGG-16 within the Gaia Network framework in the DFL scenario. The outcomes demonstrate CDL's efficacy in mitigating the non-IID challenge across varied architectures, consistently elevating performance.

Figure 1. Performance of CDL under different networks in Siamese setup.

CDL with IID Data Figure 2 illustrates CDL's efficacy across various data distributions. While CDL is primarily tailored for addressing the non-IID challenge, it also demonstrates marginal performance enhancements when applied to models trained on IID data distributions. Leveraging the Siamese setup, CDL inherits traits and behaviors akin to triplet loss. Given triplet loss's established effectiveness with IID data, it becomes evident that CDL can similarly augment model performance in scenarios where IID data is utilized.

Figure 2. Performance of different methods on IID dataset (Udacity) and non-IID dataset (Udacity+).

4. Ablation Study

Figure 3 illustrates the training results in RMSE for our two baselines, DroNet and FADNet, as well as our proposed CDL. The results showcase the convergence behavior of these methods across three datasets (Udacity+, Gazebo, Carla) under the NWS and Gaia topologies. Our proposed CDL attains a better convergence point than the two baselines. While the other methods (DroNet and FADNet) struggle to converge or exhibit poor convergence trends, our proposed CDL escapes local optima more effectively and shows less bias towards any specific silo.

Figure 3. The convergence ability of different methods under Gaia topology (top row) and NWS topology (bottom row).

5. Discussion

Although our method exhibits promising results, there are several areas for potential improvement that warrant consideration for future work:

  • Our CDL is specifically designed to address the non-IID problem, meaning its effectiveness heavily depends on the presence of non-IID characteristics within the autonomous driving data. Thus, in scenarios where data distributions across vehicles are relatively consistent or lack significant non-IID factors, the proposed contrastive divergence loss may only lead to limited performance enhancements.

  • While our proposal has been validated on autonomous driving datasets, real-world testing on actual vehicles has not yet been conducted. Conducting a study involving various driving scenarios, such as interactions with pedestrians, cyclists, and other vehicles, could provide further validation of our method's efficacy.

  • Although our approach is tailored for autonomous driving applications, its underlying principles may have applicability in other domains where non-IID data is prevalent. Exploring its effectiveness in areas like healthcare, IoT, and industrial settings could expand its potential impact.

Conclusion

We presented a new method to address the non-IID problem in federated autonomous driving using contrastive divergence loss. Our method directly reduces the effect of divergence factors in the learning process of each silo. The experiments on three benchmarking datasets demonstrate that our proposed method performs substantially better than current state-of-the-art approaches. In the future, we plan to test our strategy with more data silos and deploy the trained model using an autonomous vehicle on roads.

Scalable Group Choreography via Variational Phase Manifold Learning (Part 1)

Generating group dance motion from the music is a challenging task with several industrial applications. Although several methods have been proposed to tackle this problem, most of them prioritize optimizing the fidelity in dancing movement, constrained by predetermined dancer counts in datasets. This limitation impedes adaptability to real-world applications. Our study addresses the scalability problem in group choreography while preserving naturalness and synchronization.

Figure 1. We present a new group dance generation method that can generate a large number of dancers within a fixed resource consumption. The illustration shows a generated group dance sample with 100 dancers.

In particular, we propose a phase-based variational generative model for group dance generation on learning a generative manifold. Our method achieves high-fidelity group dance motion and enables the generation with an unlimited number of dancers while consuming only a minimal and constant amount of memory. The intensive experiments on two public datasets show that our proposed method outperforms recent state-of-the-art approaches by a large margin and is scalable to a great number of dancers beyond the training data.

Introduction

The proliferation of digital social media platforms has significantly increased the popularity of creating and sharing dance videos. This surge in interest has led to the daily production and consumption of millions of dance videos across various online platforms, attracting attention from the research community in fields such as computer vision and computer graphics. Recent advancements in these areas have focused on generating authentic dance movements in response to music, with broad applications spanning animation, virtual idols, virtual metaverses, and dance education. These technologies provide powerful tools for artists, animators, and educators, enhancing creativity and enriching the dance experience for both performers and audiences.

Figure 2. An example visualization of a virtual group dancer.

While considerable progress has been made in generating solo dance motions, creating synchronized and realistic group dance performances remains a complex and unresolved challenge. Existing methods typically face limitations in scalability, either generating dances for a fixed number of performers or suffering from high memory consumption due to architectural constraints. These approaches, often based on collaborative mechanisms such as cross-entity or global attention, struggle to scale up for larger groups of dancers, limiting their practical applicability. Moreover, the reliance on predefined datasets with a fixed number of dancers further restricts these models from being adapted to real-world scenarios requiring larger group choreographies.

To address these challenges, we propose a novel approach to scalable group dance generation using a phase-based variational generative model, termed Phase-conditioned Dance VAE (PDVAE). Unlike traditional variational autoencoders that operate in high-dimensional motion space and often struggle to capture temporal dynamics, PDVAE leverages phase parameters in the frequency domain to represent the latent motion space. This enables the generation of realistic and synchronized group dance performances without increasing computational and memory costs, even as the number of dancers grows. PDVAE provides a flexible and efficient solution for generating crowd-scale dance animations, with potential applications in diverse fields such as entertainment, virtual reality, education, and media production.

In this work, we aim to advance the state-of-the-art in scalable group dance generation by overcoming the limitations of existing methods and demonstrating the feasibility of generating large-scale, natural, and synchronized dance performances efficiently.

To summarize, our key contributions are as follows:

  • We introduce PDVAE, a phase-based variational generative model for scalable group dance generation. The method focuses on generating large-scale group dance under limited resources.
  • To effectively learn the manifold that is group-consistent (i.e., dancers within a group lie upon the same manifold), we propose a group consistency loss that enforces the networks to encode the latent phase manifold to be identical for the same group given the input music.
  • Extensive experiments along with thorough user study evaluations demonstrate the state-of-the-art performance of our model while achieving effective scalability.

Related Works

Music-driven Choreography

Crafting natural human choreography derived from music poses a complex challenge, encompassing the need for synchronization, coherence, and expressiveness between movement and musical inputs. Earlier approaches often relied on music-motion similarity constraints to ensure alignment between the generated dance and the accompanying music. Many of these methods employ heuristic algorithms to stitch together pre-existing dance segments from limited music-dance databases, successfully producing extended and realistic dance sequences. However, such techniques are constrained when attempting to generate novel dance fragments, as they depend heavily on pre-defined segments rather than innovative motion generation.

Recent advancements have focused on utilizing deep learning architectures to map music into dance motions. Various techniques, including Convolutional Networks (CNN), Recurrent Networks (RNN), Graph Neural Networks (GNN), Generative Adversarial Networks (GAN), and Transformer models, have been explored for dance generation. These models typically rely on inputs such as the current music and a brief history of previous dance motions to predict the next human poses in a sequence. For instance, innovations in multi-modal feature fusion have enabled the simultaneous integration of music and text, producing dance sequences guided by both musical and textual cues.

Despite the progress made by these methods, generating synchronized and coordinated dance movements for multiple dancers remains a significant challenge. Achieving harmony between multiple dancers requires not only temporal alignment but also the consideration of spatial relationships and interactions between performers. This adds complexity to the generation process, making it difficult for current techniques to manage multiple dancers cohesively. Additionally, these methods are often constrained by the limited number of dancers present in their training datasets.

Figure 3. Multi-person Tracking with GNN.

Several recent works have sought to address these limitations. For instance, Perez et al. propose a multimodal transformer combined with a normalizing-flow-based decoder to predict a distribution of possible future poses, offering improved flexibility in motion generation. Feng et al. introduce a motion repeat constraint for long-term generation, allowing their model to generate future frames while taking into account historical dance motions. Le et al. further explore group dance by examining consistency and diversity among dancers, ensuring that generated movements maintain coherence across multiple performers. However, these methods are still limited by the number of dancers depicted in training datasets, restricting their scalability for larger group performances.

Figure 4. Local Mesh Fitting process for conducting a dance dataset.

Overall, while significant strides have been made in music-driven dance generation, further research is needed to overcome the challenges of scalability and synchronization in group dance synthesis.

Motion Manifold Learning

Motion manifold learning has garnered significant attention in computer vision and artificial intelligence, primarily aiming to understand the fundamental structures underlying human movement and dynamics. This approach offers the capability to generate human motion patterns, providing insights into intrinsic motion dynamics, managing nonlinear relationships within motion data, and learning contextual and hierarchical representations. Consequently, numerous methodologies have emerged, each contributing distinct perspectives to advance the comprehension and synthesis of human motion.

Holden et al. pioneered motion manifold learning by generating character movements from high-level parameters mapped to a motion manifold, eliminating the need for manual preprocessing while enabling natural, smooth post-generation editing of motion sequences. MotionCLIP introduced a 3D human motion autoencoder aligned with CLIP's semantic space, facilitating semantic text-based motion generation, disentangled editing, and abstract language specification. This approach capitalizes on CLIP's rich semantic knowledge, integrating it within the motion manifold for enhanced control and interpretation of human movement. Sun et al. further advanced this field by employing VQ-VAE to learn a low-dimensional motion manifold, effectively refining motion sequences to improve coherence and continuity.

Figure 5. Manifold conduction.

In the context of group dance generation, motion manifold learning presents a promising solution to scalability challenges, particularly the restricted number of dancers present in most datasets. By learning a distribution over dance motions within the manifold, this approach enables the generation of synchronized, cohesive group dance sequences, potentially overcoming dataset limitations. This direction of research highlights the potential of manifold-based methods to enhance scalability, allowing for the synthesis of large-scale, realistic group dance performances that maintain temporal and spatial harmony.

Next

In the next part, we will dive deeply into the main proposal, which leverages manifold learning to handle the scalability of group dance generation.

Reducing Non-IID Effects in Federated Autonomous Driving with Contrastive Divergence Loss (Part 2)

In the previous post, we explored federated learning in the context of autonomous driving, its potential, and the challenges posed by non-IID data distribution, and we discussed how contrastive divergence offers a promising avenue for tackling the non-IID problem in federated learning setups. Building upon that foundation, in this post we delve deeper into the integration of the contrastive divergence loss into federated learning frameworks. We examine how this loss function is incorporated into the federated learning process and what it implies for improving model performance and addressing non-IID data challenges in autonomous driving scenarios.

1. Overview

Motivation: The effectiveness of federated learning algorithms in autonomous driving hinges on two critical factors: firstly, the ability of each local silo to glean meaningful insights from its own data, and secondly, the synchronization among neighboring silos to mitigate the impact of the non-IID problem. Recent efforts have primarily focused on addressing these challenges through various means, including optimizing accumulation processes and optimizers, proposing novel network topologies, or leveraging robust deep networks capable of handling the distributed nature of the data. However, as highlighted by Duan et al., the indiscriminate adoption of high-performance deep architectures and their associated optimizations in centralized local learning scenarios can lead to increased weight variance among local silos during the accumulation process in federated setups. This variance detrimentally impacts model convergence and may even induce divergence, underscoring the need for nuanced approaches to ensure the efficacy of federated learning in autonomous driving contexts.

Figure 1. The Siamese setup when our CDL is applied for training a federated autonomous driving model. ResNet-8 is used for both the backbone and the sub-network in the Siamese setup. During inference, the sub-network is removed. Dotted lines represent the backward process. Our CDL has two components: the positive contrastive divergence loss Lcd+\mathcal{L}_{\rm {cd^+}} and the negative regularizer term Lcd\mathcal{L}_{\rm {cd^-}}. The local regression loss Llr\mathcal{L}_{\rm {lr}} for automatic steering prediction is calculated only from the backbone network.

Siamese Network Approach: In our study, we propose a novel approach to directly tackle the non-IID problem within each local silo by addressing the challenges of learning optimal features and achieving synchronization separately. Our strategy involves implementing *two distinct networks within each silo*: one network is dedicated to extracting meaningful features from local image data, while the other focuses on minimizing the distribution gap between the current model weights and those of neighboring silos. To facilitate this, we employ a Siamese Network architecture comprising two branches. The first branch, serving as the backbone network, learns local image features for autonomous steering using a local regression loss Llr\mathcal{L}_{\rm {lr}}, while simultaneously incorporating a positive contrastive divergence loss Lcd+\mathcal{L}_{\rm {cd^+}} to assimilate knowledge from neighboring silos. Meanwhile, the second branch, referred to as the sub-network, regulates divergence factors arising from the backbone's knowledge through a contrastive regularizer term Lcd\mathcal{L}_{\rm {cd^-}}. See Figure 1 for more details.

In practice, the sub-network adopts the same weights as the backbone during the first communication round. Starting from the subsequent communication rounds, once the backbone undergoes accumulation using the equation below, each silo's local model is trained using the contrastive divergence loss. The sub-network produces auxiliary features of identical dimensions to the output features of the backbone. Throughout training, we anticipate minimal discrepancies in weights between the backbone and the sub-network when employing the contrastive divergence loss. Synchronization of weights across all silos occurs when the gradients from the backbone and sub-network learning processes exhibit minimal disparity.
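To make this setup concrete, here is a minimal PyTorch-style sketch of the two-branch arrangement described above and in Figure 1. The tiny convolutional encoder standing in for ResNet-8, the feature dimension, and all class and method names are illustrative assumptions rather than the actual implementation.

```python
import copy
import torch.nn as nn

class SmallEncoder(nn.Module):
    """Tiny convolutional encoder standing in for the ResNet-8 trunk (assumption)."""
    def __init__(self, feat_dim=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(64, feat_dim)

    def forward(self, x):
        return self.fc(self.conv(x).flatten(1))

class SiameseSilo(nn.Module):
    """Backbone learns steering from local images; the sub-network emits auxiliary
    features of the same dimension and is removed at inference time."""
    def __init__(self, feat_dim=128):
        super().__init__()
        self.backbone = SmallEncoder(feat_dim)
        self.head = nn.Linear(feat_dim, 1)    # steering-angle regression head
        self.subnet = SmallEncoder(feat_dim)  # auxiliary branch, dropped at inference

    def init_subnet_from_backbone(self):
        # First communication round: the sub-network starts from the backbone's weights.
        self.subnet.load_state_dict(copy.deepcopy(self.backbone.state_dict()))

    def forward(self, x):
        feat_b = self.backbone(x)
        return self.head(feat_b), feat_b, self.subnet(x)
```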

θi(k+1)=θi(k)αk1mh=1mLlr(θi(k),ξih(k))\theta_i\left(k + 1\right) = {\theta}_i\left(k\right)-\alpha_{k}\frac{1}{m}\sum^m_{h=1}\nabla \mathcal{L}_{\rm {lr}}\left({\theta}_i\left(k\right),\xi_i^h\left(k\right)\right)

where Llr\mathcal{L}_{\rm {lr}} is the local regression loss for autonomous steering.
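The update rule above amounts to a mini-batch SGD step on the local regression loss. Below is a hedged sketch of that step, reusing the hypothetical `SiameseSilo` class from the previous snippet; MSE is assumed for the regression loss, as the post does not fix its exact form.

```python
import torch
import torch.nn.functional as F

def local_update(silo, batch, lr):
    """One mini-batch SGD step on the local regression loss (steering prediction)."""
    images, angles = batch
    steering, _, _ = silo(images)
    loss_lr = F.mse_loss(steering.squeeze(-1), angles)  # assumed form of L_lr

    params = list(silo.backbone.parameters()) + list(silo.head.parameters())
    grads = torch.autograd.grad(loss_lr, params)
    with torch.no_grad():
        for p, g in zip(params, grads):
            p -= lr * g  # theta(k+1) = theta(k) - alpha_k * mean gradient
    return loss_lr.item()
```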

2. Contrastive Divergence Loss

In practice, we've noticed that the initial phases of federated learning often yield subpar accumulated models. Unlike other approaches that tackle the non-IID issue by refining the accumulation step whenever silos transmit their models, we directly mitigate the impact of divergence factors during the local learning phase of each silo. Our method aims to minimize the discrepancy between the distribution of accumulated weights from neighboring silos in the backbone network (representing divergence factors) and the weights specific to silo ii in the sub-network (comprising locally learned knowledge). Once the distribution between silos achieves an acceptable level of synchronization, we reduce the influence of the sub-network and prioritize the steering angle prediction task. Inspired by the contrastive loss of the original Siamese Network, our proposed Contrastive Divergence Loss is formulated as follows:

Lcd=βLcd++(1β)Lcd=βH(θib,θis)+(1β)H(θis,θib)\mathcal{L}_{\rm {cd}} = \beta \mathcal{L}_{\rm {cd^+}} + (1-\beta) \mathcal{L}_{\rm {cd^-}} = \beta \mathcal{H}(\theta^b_i, \theta^s_i) + (1-\beta) \mathcal{H}(\theta^s_i,\theta^b_i)

where Lcd+\mathcal{L}_{\rm {cd^+}} is the positive contrastive divergence term and Lcd\mathcal{L}_{\rm {cd^-}} is the negative regularizer term; H\mathcal{H} is the Kullback-Leibler divergence loss function:

H(y^,y)=f(y^)log(f(y^)f(y))\mathcal{H}(\hat{y},y) = \sum \mathbf{f}(\hat{y}) \log\left(\frac{\mathbf{f}(\hat{y})}{\mathbf{f}(y)}\right)

where y^\hat{y} is the predicted representation and yy is the dynamic soft label.
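The two equations above can be turned into a short loss routine operating on the backbone and sub-network features. The sketch below assumes that f is a softmax over the feature dimension; the temperature and the stop-gradient (detach) placement are also assumptions made for illustration, not details taken from the post.

```python
import torch.nn.functional as F

def kl_divergence(y_hat, y, temperature=1.0):
    """H(y_hat, y) = sum f(y_hat) * log(f(y_hat) / f(y)), with f assumed to be a softmax."""
    log_p = F.log_softmax(y_hat / temperature, dim=-1)
    log_q = F.log_softmax(y / temperature, dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean()

def contrastive_divergence_loss(feat_backbone, feat_subnet, beta=0.5):
    """L_cd = beta * H(backbone, subnet) + (1 - beta) * H(subnet, backbone)."""
    l_cd_pos = kl_divergence(feat_backbone, feat_subnet.detach())  # positive term, drives the backbone
    l_cd_neg = kl_divergence(feat_subnet, feat_backbone.detach())  # negative regularizer, drives the sub-network
    return beta * l_cd_pos + (1.0 - beta) * l_cd_neg
```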

Considering Lcd+\mathcal{L}_{\rm {cd^+}} in the equation above as a Bayesian statistical inference task, our goal is to estimate the model parameters θb\theta^{b*} by minimizing the Kullback-Leibler divergence H(θib,θis)\mathcal{H}(\theta^b_i, \theta^s_i) between the measured regression probability distribution of the observed local silo P0(xθis)P_0 (x|\theta^s_i) and the accumulated model P(xθib)P (x|\theta^b_i). Hence, we can assume that the model distribution has the form P(xθib)=eE(x,θib)/Z(θib)P (x|\theta^b_i) = e^{-E(x,\theta^b_i)}/Z(\theta^b_i), where Z(θib)Z(\theta^b_i) is the normalization term. However, evaluating the normalization term Z(θib)Z(\theta^b_i) is not trivial, which risks getting stuck in a local minimum. Inspired by Hinton, we use samples obtained through a Markov Chain Monte Carlo (MCMC) procedure with a specific initialization strategy to address this problem. It then follows from the equation above that Lcd+\mathcal{L}_{\rm {cd^+}} can be expressed, under the SGD algorithm in a local silo, as:

Lcd+=xP0(xθis)E(x;θib)θib+xQθib(xθis)E(x;θib)θib\mathcal{L}_{\rm {cd^+}} = -\sum_{x}P_0 (x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i} + \sum_{x}Q_{\theta^b_i} (x|\theta^s_i)\frac{\partial E(x;\theta^b_i)}{\partial \theta^b_i}

where Qθib(xθis)Q_{\theta^b_i} (x|\theta^s_i) is the measured probability distribution on the samples obtained by initializing the chain at P0(xθis)P_0 (x|\theta^s_i) and running the Markov chain forward for a defined number of steps.
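For intuition about this sampling step, the toy example below estimates a contrastive-divergence-style gradient for a simple quadratic energy: a short Langevin-style chain is initialized at the data (playing the role of the observed distribution) and run forward for k steps to obtain approximate model samples, and the gradient is the difference between the two expectation terms in the equation above. The energy function, sampler, and step sizes are illustrative assumptions, not the method used in our federated setting.

```python
import torch

def energy(x, theta):
    """Toy quadratic energy E(x; theta); a stand-in for the model's energy function."""
    return 0.5 * ((x - theta) ** 2).sum(dim=-1)

def grad_energy_wrt_theta(x, theta):
    """Average dE/dtheta over a batch of samples x."""
    th = theta.detach().clone().requires_grad_(True)
    return torch.autograd.grad(energy(x, th).mean(), th)[0]

def cd_k_gradient(x_data, theta, k=1, step=0.1):
    """CD-k style estimate: -<dE/dtheta>_data + <dE/dtheta>_model, with model samples
    drawn by a short chain initialized at the data."""
    x = x_data.detach()
    for _ in range(k):
        x = x.requires_grad_(True)
        grad_x = torch.autograd.grad(energy(x, theta.detach()).sum(), x)[0]
        x = (x - step * grad_x + (2 * step) ** 0.5 * torch.randn_like(x)).detach()
    x_model = x

    data_term = grad_energy_wrt_theta(x_data, theta)
    model_term = grad_energy_wrt_theta(x_model, theta)
    # Sign convention follows the equation above: minus the data term plus the model-sample term.
    return -data_term + model_term
```

For example, `cd_k_gradient(torch.randn(64, 4), torch.zeros(4))` returns an estimate with the same shape as `theta`; as the model samples approach the data, the estimate shrinks toward zero.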

Treating the Lcd\mathcal{L}_{\rm {cd^-}} regularizer in the equation above as a Bayesian statistical inference task as well, we can compute Lcd\mathcal{L}_{\rm {cd^-}} analogously, but with the roles of θs\theta^s and θb\theta^b reversed:

Lcd=xP0(xθib)E(x;θis)θis+xQθis(xθib)E(x;θis)θis\mathcal{L}_{\rm {cd^-}}=-\sum_{x}P_0 (x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} + \sum_{x}Q_{\theta^s_i} (x|\theta^b_i)\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}

We note that the key difference is that while the weight θib\theta^b_i of the backbone is updated by the accumulation process from the equation above, the weight θis\theta^s_i of the sub-network is not. This leads to different convergence behaviors of the contrastive divergence in Lcd+\mathcal{L}_{\rm {cd^+}} and Lcd\mathcal{L}_{\rm {cd^-}}. The negative regularizer term Lcd\mathcal{L}_{\rm {cd^-}} will converge to the state θis\theta^{s*}_i provided that Eθis\frac{\partial E}{\partial \theta^s_i} is bounded:

g(x,θis)=E(x;θis)θisxP0(x(θib,θis))E(x;θis)θisg(x,\theta^s_i) = \frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i} -\sum_{x}P_0 (x|(\theta^b_i,\theta^s_i))\frac{\partial E(x;\theta^s_i)}{\partial \theta^s_i}