Keywords: Causal Inference · Medical Imaging · Deep Learning · OCT · ResNet · Representation Learning

Abstract

Standard deep learning models for medical image analysis excel at learning correlational patterns but often fail to capture the underlying causal relationships inherent in disease pathology. This paper investigates a novel approach to enhance diagnostic accuracy by integrating principles of causal inference directly into a deep learning framework. We present a systematic study on the classification of retinal diseases from Optical Coherence Tomography (OCT) images using a ResNet18 architecture augmented with causality-aware modules. We explore two distinct mechanisms: a concatenation-based method ("Cat") and a feature-weighting method ("Mulcat"), along with different causality map computations (Lehmer mean vs. Max) and extraction modes (Full/Bool and Causes/Effects). Our key finding: a Mulcat model using a Lehmer-derived causality map with boolean weighting of causal features achieves 88.90% accuracy, demonstrating that selective integration of causal signals, rather than amplification, is the key to improved performance.

Causality-Aware Architecture Overview
Fig. 1, Overview of the causality-aware ResNet18 architecture integrating causal inference modules for enhanced OCT classification.
• • •

Introduction

Optical Coherence Tomography (OCT) provides high-resolution cross-sectional images of the retina, critical for diagnosing diabetic retinopathy, macular degeneration, and glaucoma. CNNs like ResNet have been widely applied to automate classification of these images, often matching human-level performance.

However, a fundamental limitation of traditional CNNs is their reliance on statistical correlations. These models recognize patterns without an explicit understanding of the cause-and-effect relationships that define disease progression. This "correlation, not causation" paradigm can lead to brittle models that struggle with out-of-distribution examples, a critical drawback in high-stakes medical applications.

Recent advances in causal inference offer a promising direction. By developing methods to estimate and integrate causal signals from data, it is possible to guide a neural network to learn features that are causally relevant, not merely correlated. This study presents a rigorous empirical analysis of causality-aware ResNet18 variants, uncovering a novel and counter-intuitive configuration that yields superior performance.

Methodology

Our framework augments the ResNet18 backbone with specialized modules that process a computed "causality map", a k×k matrix where each element C(i,j) quantifies the causal influence of feature map Fi on Fj.

Causality Map Generation

After the image is processed by ResNet18's convolutional layers, two methods construct the causality map:

Max Method

A correlation-based proxy that estimates causal influence by taking the maximum value of the feature interaction tensor. It is computationally efficient but, being symmetric, cannot distinguish the direction of influence.
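The Max method can be sketched as follows. This is a minimal illustration under the assumption that the interaction tensor is the element-wise product of feature-map pairs at each spatial location; the exact tensor in the original work may differ:

```python
import numpy as np

def causality_map_max(feats):
    """Max-based causality proxy for feature maps of shape (k, h, w).
    Builds the pairwise interaction tensor T[i, j, s] = F_i(s) * F_j(s)
    over spatial sites s, then takes the maximum over s. The result is
    symmetric, so it cannot distinguish whether F_i influences F_j or
    vice versa."""
    k = feats.shape[0]
    flat = feats.reshape(k, -1)                   # (k, h*w)
    inter = flat[:, None, :] * flat[None, :, :]   # (k, k, h*w)
    return inter.max(axis=2)                      # (k, k)
```

Note that the output is symmetric by construction, which is precisely the directionality limitation described above.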

Lehmer Mean Method

A flexible, parameter-controlled mean that can be tuned to be more or less sensitive to feature interactions. Fine-tuning found p = 0 (the harmonic mean) optimal for training from scratch, suggesting that a conservative aggregation, dominated by the weaker feature interactions, best models feature influence in OCT analysis.
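For reference, the Lehmer mean itself is straightforward to implement; a minimal version is shown below (the spatial aggregation pipeline around it is omitted, and the epsilon guard is an implementation detail added here):

```python
import numpy as np

def lehmer_mean(x, p):
    """Lehmer mean L_p(x) = sum(x_i**p) / sum(x_i**(p-1)).
    p = 1 gives the arithmetic mean, p = 0 the harmonic mean,
    and p = 2 the contraharmonic mean."""
    x = np.asarray(x, dtype=float)
    eps = 1e-12  # guard against division by zero when p <= 1
    return ((x + eps) ** p).sum() / ((x + eps) ** (p - 1)).sum()
```

Varying p biases the estimate toward the larger values (p > 1) or the smaller ones (p < 1), which is how the sensitivity tuning mentioned above works.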

Causality-Aware Modules

Cat (Concatenate) Module

Flattens both the k feature maps and the k×k causality map into separate vectors, then concatenates them for the classifier. This is a passive strategy: the causality map is provided as supplementary metadata alongside the visual features.
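A minimal sketch of the Cat strategy; shapes and the function name are illustrative, not the original implementation:

```python
import numpy as np

def cat_module(feats, cmap):
    """'Cat' integration: flatten the k feature maps and the k x k
    causality map and concatenate them into one classifier input.
    The causal signal rides along as extra metadata; it never
    modifies the visual features themselves."""
    return np.concatenate([feats.reshape(-1), cmap.reshape(-1)])
```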

Mulcat (Multiply & Concatenate) Module

Uses the causality map to directly modulate the feature maps via element-wise multiplication before classification. The module is controlled by two binary parameters:

  • Direction (Causes / Effects): Focus on root-cause features vs. downstream consequences
  • Weighting Mode (Full / Bool): Continuous weighting vs. binary gating of causally relevant features
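A sketch of the Mulcat idea under stated assumptions: "causes" scores are taken here as row sums of the causality map and "effects" as column sums, and the Bool gate keeps features scoring above the mean. The original work may define these operations differently:

```python
import numpy as np

def mulcat_module(feats, cmap, direction="causes", mode="bool"):
    """Weight the k feature maps by per-feature causal scores, then
    concatenate the weighted maps with the originals.
    feats: (k, h, w) feature maps; cmap: (k, k) causality map where
    cmap[i, j] quantifies the influence of feature i on feature j."""
    axis = 1 if direction == "causes" else 0   # row sums vs. column sums
    scores = cmap.sum(axis=axis)               # (k,) causal relevance
    if mode == "bool":
        # Binary gating: every selected feature receives equal weight 1.
        scores = (scores > scores.mean()).astype(float)
    weighted = feats * scores[:, None, None]   # broadcast over h, w
    return np.concatenate([feats.reshape(-1), weighted.reshape(-1)])
```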

Experiments & Results

Lehmer Parameter Fine-Tuning

A preliminary experiment determined the optimal Lehmer parameter. Training from scratch with p = 0 achieved 91% accuracy, far outperforming other values and pretrained initializations.

Lehmer Parameter (p) | Accuracy (Pretrained) | Accuracy (From Scratch)
0 | 47% | 91%
1 | 47% | 73%
-1 | 77% | 35%
-2 | 83% | 47%
Table 1, Lehmer parameter fine-tuning. p = 0 (the harmonic mean) is clearly optimal when training from scratch.

Main Classification Results

Model | Factor Mode | Direction | Test Accuracy (%)
ResNet18 (baseline) | — | — | 93.63
+ Cat (Max) | — | — | 93.65
+ Cat (Lehmer) | — | — | 91.58
+ Mulcat (Max) | Full | Causes | 56.88
+ Mulcat (Max) | Full | Effects | 78.34
+ Mulcat (Max) | Bool | Effects | 83.69
+ Mulcat (Lehmer) | Full | Causes | 88.65
+ Mulcat (Lehmer) | Full | Effects | 74.48
+ Mulcat (Lehmer) | Bool | Causes | 88.90
+ Mulcat (Lehmer) | Bool | Effects | 86.04
Table 2, Main results on the OCT dataset. The Mulcat (Lehmer, Bool, Causes) configuration is the best-performing causal variant.

Key Observation

The "Full" weighting mode, which directly amplifies features by their causal scores, consistently performs poorly, collapsing to 56.88% in the worst case. The "Bool" mode, treating causality as a selection gate rather than an amplifier, proves far more robust. More causal signal is not always better; selective integration is the key.
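The contrast can be made concrete with hypothetical scores: under Full weighting the dominant feature dwarfs the rest, while the Bool gate keeps every selected feature on equal footing. Thresholding at the mean is an assumption made here for illustration:

```python
import numpy as np

scores = np.array([0.90, 0.60, 0.55, 0.05])         # one feature dominates
full_weights = scores                                # 'Full': raw multipliers
bool_gate = (scores > scores.mean()).astype(float)   # 'Bool': select, then weight equally

print(full_weights)  # the 0.90 feature gets 18x the weight of the 0.05 one
print(bool_gate)     # [1. 1. 1. 0.] -- three features kept, weighted equally
```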

Discussion

The dramatic failure of the Full weighting mode suggests that raw causal scores are too noisy to be used as direct multipliers. Amplifying a single causally dominant feature biases the model, ignoring subtler visual cues that are also diagnostically relevant.

The Bool mode's success demonstrates that using causality as a selection mechanism, identifying which features matter, then weighting them equally, leads to more robust representations. The optimal direction (Causes over Effects) further aligns with clinical reasoning: identifying the root physiological cause is more valuable than detecting downstream consequences for accurate classification.

The harmonic mean (p = 0) as the optimal Lehmer setting implies that a conservative aggregation, one dominated by the weaker feature interactions rather than skewed by a few strong activations, best captures how features interact in OCT retinal pathology.

Conclusion

This work demonstrates that the method of integrating causal information is critically important. Simple concatenation offers little benefit; aggressive amplification is detrimental; but nuanced, boolean-gated selection of causally relevant features based on the Lehmer harmonic mean (p = 0) provides a highly competitive and interpretable result.

The identified configuration, Mulcat + Lehmer (p=0) + Bool + Causes, at 88.90% accuracy stands as the best causal variant and establishes a strong proof-of-concept for causality-aware representation learning in medical imaging. Future work will explore more sophisticated causality estimation methods, multi-modal extensions, and clinical validation of the discovered causal patterns.

• • •

References

  1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR.
  2. Terziyan, V., & Vitko, O. (2023). Causality-aware convolutional neural networks for advanced image classification and generation. Procedia Computer Science, 217, 495–506.
  3. Carloni, G., & Colantonio, S. (2024). Exploiting causality signals in medical images. Expert Systems with Applications, 249, 123433.
  4. Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
  5. Kermany, D. S., et al. (2018). Identifying Medical Diagnoses by Image-Based Deep Learning. Cell, 172(5), 1122–1131.
  6. Schölkopf, B. (2022). Causality for Machine Learning. arXiv:2206.14754.

Contact

For further information or collaboration opportunities:
Nikhileswara Rao Sulake, nikhil01446@gmail.com  ·  LinkedIn  ·  GitHub