[BIBM 2025] CafeMed: Causal Attention Fusion Enhanced Medication Recommendation
CafeMed is a medication recommendation system that leverages causal attention fusion mechanisms to enhance prescription accuracy and safety in clinical settings.
CafeMed/
├── data/
│ ├── input/ # Raw data and mapping files
│ ├── output/ # Processed data files
│ ├── graphs/ # Causal graph data
│ ├── processing.py # Data preprocessing script
│ └── ddi_mask_H.py # DDI mask generation script
├── src/ # Source code
│ ├── modules/ # Model definitions
│ ├── util.py # Utilities and metrics
│ ├── training.py # Training functions
│ └── main.py # Main training/evaluation script
└── saved/
    ├── trained_model/ # Pre-trained model example
    └── parameter_report.txt # Training parameters log
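Before running anything, it can help to confirm that a checkout matches the layout above. The following is a minimal sketch, assuming the relative paths shown in the tree (the optional `saved/` folder is left out). `missing_paths` and `EXPECTED_PATHS` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path
from typing import List, Sequence

# Relative paths copied from the directory tree above (assumed layout of a fresh checkout).
EXPECTED_PATHS = (
    "data/input",
    "data/output",
    "data/graphs",
    "data/processing.py",
    "data/ddi_mask_H.py",
    "src/modules",
    "src/util.py",
    "src/training.py",
    "src/main.py",
)


def missing_paths(repo_root: str, expected: Sequence[str] = EXPECTED_PATHS) -> List[str]:
    """Return the entries of `expected` that do not exist under `repo_root`."""
    root = Path(repo_root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_paths(".")
    if missing:
        print("Missing: " + ", ".join(missing))
    else:
        print("Layout looks complete.")
```

Run it from the repository root. An empty result only means the paths exist, not that the data files inside them are valid.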
- drug-atc.csv, ndc2atc_level4.csv, ndc2rxnorm_mapping.txt - Drug code mapping files
- idx2ndc.pkl - ATC-4 to RxNorm code mapping
- idx2drug.pkl - Drug ID to SMILES string dictionary

- voc_final.pkl - Vocabulary mappings for diagnosis/procedure/medication codes
- ddi_A_final.pkl - Drug-drug interaction adjacency matrix
- ddi_matrix_H.pkl - DDI mask structure (generated by ddi_mask_H.py)
- records_final.pkl - Final EHR records (users must generate this file themselves by following the processing steps)

- causal_graph.pkl - Causal graphs in DAG format
- Diag_Med_causal_effect.pkl - Diagnosis-medication causal effects
- Proc_Med_causal_effect.pkl - Procedure-medication causal effects
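The DDI adjacency matrix is what makes prescription safety measurable. A common metric in medication recommendation is the DDI rate: the fraction of drug pairs within each prescribed set that are known to interact. The sketch below shows that usual definition. It assumes `ddi_A_final.pkl` unpickles to a square 0/1 NumPy array indexed by the medication IDs in `voc_final.pkl`. The metric actually implemented in `src/util.py` may differ, and `ddi_rate` is a hypothetical name.

```python
from itertools import combinations
from typing import Iterable, Sequence

import numpy as np


def ddi_rate(med_sets: Iterable[Sequence[int]], ddi_adj: np.ndarray) -> float:
    """Fraction of unordered drug pairs, pooled over all visits, that interact.

    Assumption: ddi_adj[i, j] is nonzero when medications i and j interact;
    both directions are checked in case the matrix is not symmetric.
    """
    interacting = 0
    total = 0
    for meds in med_sets:
        # Each unordered pair of distinct drugs in one visit is counted once.
        for a, b in combinations(sorted(set(meds)), 2):
            total += 1
            if ddi_adj[a, b] or ddi_adj[b, a]:
                interacting += 1
    return interacting / total if total else 0.0


if __name__ == "__main__":
    # Toy 4-drug matrix (illustrative only): drugs 0 and 1 interact.
    toy = np.zeros((4, 4), dtype=int)
    toy[0, 1] = toy[1, 0] = 1
    print(ddi_rate([[0, 1, 2], [2, 3]], toy))  # → 0.25
```

If the assumption about the file holds, the real matrix can be loaded with `dill.load` on the opened file (`open(path, "rb")`) and passed in place of `toy`. Lower values indicate safer prescriptions.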
python == 3.8.17
torch == 2.0.1
dill == 0.3.6
numpy == 1.22.3
pandas == 2.0.2
torch-geometric == 2.3.1
cdt == 0.6.0
dowhy == 0.10.1
statsmodels == 0.14.0

pip install torch==2.0.1 torch-geometric==2.3.1
pip install dill==0.3.6 numpy==1.22.3 pandas==2.0.2
pip install cdt==0.6.0 dowhy==0.10.1 statsmodels==0.14.0

MIMIC-III Dataset:
- Apply for access at https://physionet.org/content/mimiciii/1.4/
- Download and extract the following files to data/input/:
  - PROCEDURES_ICD.csv.gz
  - PRESCRIPTIONS.csv.gz
  - DIAGNOSES_ICD.csv.gz
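Before running the preprocessing script, you can confirm that a downloaded MIMIC-III table is readable with pandas. The following is a minimal sketch that groups diagnosis codes by admission. It assumes the standard MIMIC-III `DIAGNOSES_ICD` columns (`HADM_ID`, `ICD9_CODE`) and does not reproduce whatever filtering `data/processing.py` applies. `diagnoses_by_admission` is a hypothetical helper.

```python
import pandas as pd


def diagnoses_by_admission(path: str) -> pd.Series:
    """Map each HADM_ID to its list of ICD-9 diagnosis codes, in file order.

    pandas infers gzip compression from a ".gz" suffix, so the archive can be
    read directly. Codes are read as strings to keep leading zeros (e.g. "0389").
    """
    df = pd.read_csv(path, usecols=["HADM_ID", "ICD9_CODE"], dtype={"ICD9_CODE": str})
    df = df.dropna(subset=["ICD9_CODE"])
    return df.groupby("HADM_ID")["ICD9_CODE"].agg(list)
```

For example, `diagnoses_by_admission("data/input/DIAGNOSES_ICD.csv.gz").head()` shows the first few admissions. MIMIC-IV uses lowercase column names (for example `hadm_id`, `icd_code`), so the `usecols` list would need adjusting there.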
MIMIC-IV Dataset: Follow the same process as MIMIC-III.
DDI Data:
Download the DDI file from Google Drive and place it in the data/input folder.
# Process raw data
python data/processing.py
# Generate DDI mask
python data/ddi_mask_H.py

# Train and evaluate
cd src
python main.py
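These steps can also be chained from one Python script, so that a failing preprocessing step stops the run before training starts. The following is a minimal sketch. It assumes the script is launched from the repository root and that none of the three scripts needs command-line arguments, since none are shown here. `STEPS` and `run_pipeline` are hypothetical names, not part of the repository.

```python
import subprocess
import sys
from pathlib import Path
from typing import List, Tuple

# (argv, working directory relative to the repo root), mirroring the commands above.
STEPS: List[Tuple[List[str], str]] = [
    ([sys.executable, "data/processing.py"], "."),  # process raw data
    ([sys.executable, "data/ddi_mask_H.py"], "."),  # generate DDI mask
    ([sys.executable, "main.py"], "src"),           # train and evaluate
]


def run_pipeline(repo_root: str = ".", dry_run: bool = False) -> List[str]:
    """Run each step in order and return a log of the commands.

    With dry_run=True nothing is executed; only the log is built.
    check=True makes subprocess.run raise CalledProcessError on a non-zero exit,
    which stops the remaining steps.
    """
    log = []
    for argv, rel_cwd in STEPS:
        log.append(f"cd {rel_cwd} && python {' '.join(argv[1:])}")
        if not dry_run:
            subprocess.run(argv, cwd=Path(repo_root) / rel_cwd, check=True)
    return log
```

Calling `run_pipeline(dry_run=True)` only returns the command list, which lets you check paths before starting a long training run.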
