Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 172 additions & 42 deletions medcat-v2/medcat/components/linking/embedding_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from collections import defaultdict
import logging
import math
import numpy as np

from medcat.utils.import_utils import ensure_optional_extras_installed
import medcat
Expand Down Expand Up @@ -85,8 +86,9 @@ def __init__(self, cdb: CDB, config: Config) -> None:
]
for name in self._name_keys
]
self._initialize_filter_structures()

def create_embeddings(self,
def create_embeddings(self,
embedding_model_name: Optional[str] = None,
max_length: Optional[int] = None,
):
Expand Down Expand Up @@ -281,6 +283,170 @@ def _get_context_vectors(
texts.append(text)
return self._embed(texts, self.device)

def _initialize_cui_name_mapping(self) -> None:
"""Call this once during initialization to pre-compute CUI->name."""
self._cui_to_name_mask = {}

for cui, cui_idx in self._cui_to_idx.items():
mask = torch.tensor(
[cui_idx in name_cui_idxs
for name_cui_idxs in self._name_to_cui_idxs],
dtype=torch.bool,
device=self.device
)
self._cui_to_name_mask[cui] = mask

# Cache _has_cuis_all as well
self._has_cuis_all_cached = torch.tensor(
[bool(self.cdb.name2info[name]["per_cui_status"])
for name in self._name_keys],
device=self.device,
dtype=torch.bool,
)

def _initialize_filter_structures(self) -> None:
"""Call once during initialization to create efficient lookup structures."""
# Build an inverted index: cui_idx -> list of name indices that contain it
# This is the KEY optimization - we flip the lookup direction
if not hasattr(self, '_cui_idx_to_name_idxs'):
cui2name_indices: defaultdict[
int, list[int]] = defaultdict(list)

for name_idx, cui_idxs in enumerate(self._name_to_cui_idxs):
for cui_idx in cui_idxs:
cui2name_indices[cui_idx].append(name_idx)

# Convert lists to numpy arrays for faster indexing
self._cui_idx_to_name_idxs = {
cui_idx: np.array(name_idxs, dtype=np.int32)
for cui_idx, name_idxs in cui2name_indices.items()
}

# Cache _has_cuis_all
if not hasattr(self, '_has_cuis_all_cached'):
self._has_cuis_all_cached = torch.tensor(
[bool(self.cdb.name2info[name]["per_cui_status"])
for name in self._name_keys],
device=self.device,
dtype=torch.bool,
)

def _get_include_filters_1cui(
self, cui: str, n: int) -> torch.Tensor:
"""Optimized single CUI include filter using inverted index."""
if cui not in self._cui_to_idx:
return torch.zeros(n, dtype=torch.bool, device=self.device)

cui_idx = self._cui_to_idx[cui]

# Use inverted index: get all name indices that contain this CUI
if cui_idx in self._cui_idx_to_name_idxs:
name_indices = self._cui_idx_to_name_idxs[cui_idx]

# Create mask by setting specific indices to True
allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
allowed_mask[torch.from_numpy(name_indices).to(self.device)] = True
return allowed_mask
else:
return torch.zeros(n, dtype=torch.bool, device=self.device)

def _get_include_filters_multi_cui(
self, include_set: Set[str], n: int) -> torch.Tensor:
"""Optimized multi-CUI include filter using inverted index."""
include_cui_idxs = [
self._cui_to_idx[cui] for cui in include_set
if cui in self._cui_to_idx
]

if not include_cui_idxs:
return torch.zeros(n, dtype=torch.bool, device=self.device)

# Collect all name indices from inverted index
all_name_indices_list: list[np.ndarray] = []
for cui_idx in include_cui_idxs:
if cui_idx in self._cui_idx_to_name_idxs:
all_name_indices_list.append(
self._cui_idx_to_name_idxs[cui_idx])

if not all_name_indices_list:
return torch.zeros(n, dtype=torch.bool, device=self.device)

# Concatenate and get unique indices
all_name_indices = np.unique(
np.concatenate(all_name_indices_list))

# Create mask
allowed_mask = torch.zeros(n, dtype=torch.bool, device=self.device)
allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = True
return allowed_mask

def _get_include_filters(
self, include_set: Set[str], n: int) -> torch.Tensor:
"""Route to appropriate include filter method."""
if len(include_set) == 1:
cui = next(iter(include_set))
return self._get_include_filters_1cui(cui, n)
else:
return self._get_include_filters_multi_cui(
include_set, n)

def _get_exclude_filters_1cui(
self, allowed_mask: torch.Tensor, cui: str) -> torch.Tensor:
"""Optimized single CUI exclude filter using inverted index."""
if cui not in self._cui_to_idx:
return allowed_mask

cui_idx = self._cui_to_idx[cui]

if cui_idx in self._cui_idx_to_name_idxs:
name_indices = self._cui_idx_to_name_idxs[cui_idx]
# Set specific indices to False
allowed_mask[
torch.from_numpy(name_indices).to(self.device)] = False

return allowed_mask

def _get_exclude_filters_multi_cui(
self, allowed_mask: torch.Tensor, exclude_set: Set[str],
) -> torch.Tensor:
"""Optimized multi-CUI exclude filter using inverted index."""
exclude_cui_idxs = [
self._cui_to_idx[cui] for cui in exclude_set
if cui in self._cui_to_idx
]

if not exclude_cui_idxs:
return allowed_mask

# Collect all name indices to exclude
_all_name_indices: list[np.ndarray] = []
for cui_idx in exclude_cui_idxs:
if cui_idx in self._cui_idx_to_name_idxs:
_all_name_indices.append(self._cui_idx_to_name_idxs[cui_idx])

if _all_name_indices:
all_name_indices = np.unique(np.concatenate(_all_name_indices))
allowed_mask[torch.from_numpy(all_name_indices).to(self.device)] = False

return allowed_mask

def _get_exclude_filters(
self, exclude_set: Set[str], n: int) -> torch.Tensor:
"""Route to appropriate exclude filter method."""
# Start with all allowed
allowed_mask = torch.ones(n, dtype=torch.bool, device=self.device)

if not exclude_set:
return allowed_mask

if len(exclude_set) == 1:
cui = next(iter(exclude_set))
return self._get_exclude_filters_1cui(
allowed_mask, cui)
else:
return self._get_exclude_filters_multi_cui(
allowed_mask, exclude_set)

def _set_filters(self) -> None:
include_set = self.cnf_l.filters.cuis
exclude_set = self.cnf_l.filters.cuis_exclude
Expand All @@ -295,54 +461,18 @@ def _set_filters(self) -> None:
return

n = len(self._name_keys)
allowed_mask = torch.empty(n, dtype=torch.bool, device=self.device)

if include_set:
# if in include set, ignore exclude set.
allowed_mask[:] = False
include_cui_idxs = {
self._cui_to_idx[cui] for cui in include_set if cui in self._cui_to_idx
}
include_idxs = [
name_idx
for name_idx, name_cui_idxs in enumerate(self._name_to_cui_idxs)
if any(cui in include_cui_idxs for cui in name_cui_idxs)
]
allowed_mask[
torch.tensor(include_idxs, dtype=torch.long, device=self.device)
] = True
allowed_mask = self._get_include_filters(
include_set, n)
else:
# only look at exclude if there's no include set
allowed_mask[:] = True
if exclude_set:
exclude_cui_idxs = {
self._cui_to_idx[cui]
for cui in exclude_set
if cui in self._cui_to_idx
}
exclude_idxs = [
i
for i, name_cui_idxs in enumerate(self._name_to_cui_idxs)
if any(ci in exclude_cui_idxs for ci in name_cui_idxs)
]
allowed_mask[
torch.tensor(exclude_idxs, dtype=torch.long, device=self.device)
] = False
allowed_mask = self._get_exclude_filters(
exclude_set, n)

# checking if a name has at least 1 cui related to it.
_has_cuis_all = torch.tensor(
[
bool(self.cdb.name2info[name]["per_cui_status"])
for name in self._name_keys
],
device=self.device,
dtype=torch.bool,
)
self._valid_names = _has_cuis_all & allowed_mask
self._valid_names = self._has_cuis_all_cached & allowed_mask
self._last_include_set = set(include_set) if include_set is not None else None
self._last_exclude_set = set(exclude_set) if exclude_set is not None else None


def _disambiguate_by_cui(
self, cui_candidates: list[str], scores: Tensor
) -> tuple[str, float]:
Expand Down
Loading