Skip to content
Open
9 changes: 7 additions & 2 deletions dice_ml/explainer_interfaces/dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
# post-hoc operation on continuous features to enhance sparsity - only for public data
if posthoc_sparsity_param is not None and posthoc_sparsity_param > 0 and 'data_df' in self.data_interface.__dict__:
self.final_cfs_df_sparse = copy.deepcopy(self.final_cfs)
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse, query_instance,
self.final_cfs_df_sparse = self.do_posthoc_sparsity_enhancement(self.final_cfs_df_sparse,
query_instance,
posthoc_sparsity_param,
posthoc_sparsity_algorithm)
else:
Expand All @@ -260,10 +261,14 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
if total_cfs_found < total_CFs:
self.elapsed = timeit.default_timer() - start_time
m, s = divmod(self.elapsed, 60)
print('Only %d (required %d) ' % (total_cfs_found, self.total_CFs),
print('Only %d (required %d) ' % (total_cfs_found, total_CFs),
'Diverse Counterfactuals found for the given configuation, perhaps ',
'change the query instance or the features to vary...' '; total time taken: %02d' % m,
'min %02d' % s, 'sec')
elif total_cfs_found == 0:
print(
'No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec')

Expand Down
28 changes: 16 additions & 12 deletions dice_ml/explainer_interfaces/dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ def do_random_init(self, num_inits, features_to_vary, query_instance, desired_cl
def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desired_range):
cfs = self.label_encode(cfs)
cfs = cfs.reset_index(drop=True)

self.cfs = np.zeros((self.population_size, self.data_interface.number_of_features))
for kx in range(self.population_size):
row = []
for kx in range(self.population_size*5):
if kx >= len(cfs):
break
one_init = np.zeros(self.data_interface.number_of_features)
Expand All @@ -143,16 +142,18 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir
one_init[jx] = query_instance[jx]
else:
one_init[jx] = np.random.choice(self.feature_range[feature])
self.cfs[kx] = one_init
t = tuple(one_init)
if t not in row:
row.append(t)
if len(row) == self.population_size:
break
kx += 1
self.cfs = np.array(row)

new_array = [tuple(row) for row in self.cfs]
uniques = np.unique(new_array, axis=0)

if len(uniques) != self.population_size:
if len(self.cfs) != self.population_size:
remaining_cfs = self.do_random_init(
self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range)
self.cfs = np.concatenate([uniques, remaining_cfs])
self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range)
self.cfs = np.concatenate([self.cfs, remaining_cfs])

def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range,
desired_class,
Expand Down Expand Up @@ -466,8 +467,8 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
if rest_members > 0:
new_generation_2 = np.zeros((rest_members, self.data_interface.number_of_features))
for new_gen_idx in range(rest_members):
parent1 = random.choice(population[:int(len(population) / 2)])
parent2 = random.choice(population[:int(len(population) / 2)])
parent1 = random.choice(population[:max(int(len(population) / 2), 1)])
parent2 = random.choice(population[:max(int(len(population) / 2), 1)])
child = self.mate(parent1, parent2, features_to_vary, query_instance)
new_generation_2[new_gen_idx] = child

Expand Down Expand Up @@ -514,6 +515,9 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class,
if len(self.final_cfs) == self.total_CFs:
print('Diverse Counterfactuals found! total time taken: %02d' %
m, 'min %02d' % s, 'sec')
elif len(self.final_cfs) == 0:
print('No Counterfactuals found for the given configuration, perhaps try with different parameters...',
'; total time taken: %02d' % m, 'min %02d' % s, 'sec')
else:
print('Only %d (required %d) ' % (len(self.final_cfs), self.total_CFs),
'Diverse Counterfactuals found for the given configuation, perhaps ',
Expand Down
10 changes: 8 additions & 2 deletions dice_ml/explainer_interfaces/dice_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,17 @@ class of query_instance for binary classification.
cfs_df = None
candidate_cfs = pd.DataFrame(
np.repeat(query_instance.values, sample_size, axis=0), columns=query_instance.columns)
# Loop to change one feature at a time, then two features, and so on.
# Loop intended to change one feature at a time, then two features, and so on.
# NOTE: the active code below only picks a single feature per sample on each pass.
for num_features_to_vary in range(1, len(self.features_to_vary)+1):
# The commented-out lines below allow more values to change as num_features_to_vary increases;
# if you enable them, use .loc instead of .at.
# They are deliberately left commented out so that you can choose.
# That variant is slower, but more complete, and still faster than genetic/KD-tree.
# selected_features = np.random.choice(self.features_to_vary, (sample_size, num_features_to_vary), replace=True)
selected_features = np.random.choice(self.features_to_vary, (sample_size, 1), replace=True)
for k in range(sample_size):
candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]]
candidate_cfs.at[k, selected_features[k][0]] = random_instances._get_value(k, selected_features[k][0])
# If you only want to change one feature, you should use _get_value
# candidate_cfs.iloc[k][selected_features[k]]=random_instances.iloc[k][selected_features[k]]
scores = self.predict_fn(candidate_cfs)
validity = self.decide_cf_validity(scores)
if sum(validity) > 0:
Expand Down
62 changes: 41 additions & 21 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
desired_class="opposite", desired_range=None,
permitted_range=None, features_to_vary="all",
stopping_threshold=0.5, posthoc_sparsity_param=0.1,
posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
posthoc_sparsity_algorithm=None, verbose=False, **kwargs):
"""General method for generating counterfactuals.

:param query_instances: Input point(s) for which counterfactuals are to be generated.
Expand Down Expand Up @@ -81,11 +81,23 @@ def generate_counterfactuals(self, query_instances, total_CFs,
if total_CFs <= 0:
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
if total_CFs > 10:
if posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'binary'
elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear':
import warnings
warnings.warn(
"The number of counterfactuals (total_CFs={}) generated per query instance could take much time; "
"if too slow try to change the parameter 'posthoc_sparsity_algorithm' from 'linear' to "
"'binary' search!".format(total_CFs))
elif posthoc_sparsity_algorithm is None:
posthoc_sparsity_algorithm = 'linear'

cf_examples_arr = []
query_instances_list = []
if isinstance(query_instances, pd.DataFrame):
for ix in range(query_instances.shape[0]):
query_instances_list.append(query_instances[ix:(ix+1)])
query_instances_list.append(query_instances[ix:(ix + 1)])
elif isinstance(query_instances, Iterable):
query_instances_list = query_instances

Expand Down Expand Up @@ -179,11 +191,14 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query

if feature not in features_to_vary and permitted_range is not None:
if feature in permitted_range and feature in self.data_interface.continuous_feature_names:
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][1]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
if not permitted_range[feature][0] <= query_instance[feature].values[0] <= permitted_range[feature][\
1]:
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")
elif feature in permitted_range and feature in self.data_interface.categorical_feature_names:
if query_instance[feature].values[0] not in self.feature_range[feature]:
raise ValueError("Feature:", feature, "is outside the permitted range and isn't allowed to vary.")
raise ValueError("Feature:", feature,
"is outside the permitted range and isn't allowed to vary.")

def local_feature_importance(self, query_instances, cf_examples_list=None,
total_CFs=10,
Expand Down Expand Up @@ -429,12 +444,13 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post
cfs_preds_sparse = []

for cf_ix in list(final_cfs_sparse.index):
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
for feature in features_sorted:
# current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names])
# feat_ix = self.data_interface.continuous_feature_names.index(feature)
diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature])
if(abs(diff) <= quantiles[feature]):
if (abs(diff) <= quantiles[feature]):
if posthoc_sparsity_algorithm == "linear":
final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix,
feature, final_cfs_sparse, current_pred)
Expand All @@ -455,13 +471,14 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
query_instance greedily until the prediction class changes."""

old_diff = diff
change = (10**-decimal_prec[feature]) # the minimal possible change for a feature
change = (10 ** -decimal_prec[feature]) # the minimal possible change for a feature
current_pred = current_pred_orig
if self.model.model_type == ModelTypes.Classifier:
while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)):
while ((abs(diff) > 10e-4) and (np.sign(diff * old_diff) > 0) and self.is_cf_valid(current_pred)):
old_val = int(final_cfs_sparse.at[cf_ix, feature])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
final_cfs_sparse.at[cf_ix, feature] += np.sign(diff) * change
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
old_diff = diff

if not self.is_cf_valid(current_pred):
Expand Down Expand Up @@ -494,11 +511,12 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = query_instance[feature].iat[0]

while left <= right:
current_val = left + ((right - left)/2)
current_val = left + ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break
Expand All @@ -513,19 +531,20 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f
right = int(final_cfs_sparse.at[cf_ix, feature])

while right >= left:
current_val = right - ((right - left)/2)
current_val = right - ((right - left) / 2)
current_val = round(current_val, decimal_prec[feature])

final_cfs_sparse.at[cf_ix, feature] = current_val
current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])
current_pred = self.predict_fn_for_sparsity(
final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names])

if current_val == right or current_val == left:
break

if self.is_cf_valid(current_pred):
right = current_val - (10**-decimal_prec[feature])
right = current_val - (10 ** -decimal_prec[feature])
else:
left = current_val + (10**-decimal_prec[feature])
left = current_val + (10 ** -decimal_prec[feature])

return final_cfs_sparse

Expand Down Expand Up @@ -567,7 +586,7 @@ def infer_target_cfs_class(self, desired_class_input, original_pred, num_output_
raise UserConfigValidationException("Desired class not present in training data!")
else:
raise UserConfigValidationException("The target class for {0} could not be identified".format(
desired_class_input))
desired_class_input))

def infer_target_cfs_range(self, desired_range_input):
target_range = None
Expand All @@ -586,7 +605,7 @@ def decide_cf_validity(self, model_outputs):
pred = model_outputs[i]
if self.model.model_type == ModelTypes.Classifier:
if self.num_output_nodes == 2: # binary
pred_1 = pred[self.num_output_nodes-1]
pred_1 = pred[self.num_output_nodes - 1]
validity[i] = 1 if \
((self.target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(self.target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else 0
Expand Down Expand Up @@ -623,7 +642,7 @@ def is_cf_valid(self, model_score):
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
return validity
if self.num_output_nodes == 2: # binary
pred_1 = model_score[self.num_output_nodes-1]
pred_1 = model_score[self.num_output_nodes - 1]
validity = True if \
((target_cf_class == 0 and pred_1 <= self.stopping_threshold) or
(target_cf_class == 1 and pred_1 >= self.stopping_threshold)) else False
Expand Down Expand Up @@ -699,7 +718,8 @@ def round_to_precision(self):
for ix, feature in enumerate(self.data_interface.continuous_feature_names):
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
if self.final_cfs_df_sparse is not None:
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(
precisions[ix])

def _check_any_counterfactuals_computed(self, cf_examples_arr):
"""Check if any counterfactuals were generated for any query point."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,4 +329,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
1 change: 1 addition & 0 deletions requirements-linting.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
flake8==3.9.2
flake8-bugbear==21.11.29
flake8-blind-except==0.1.1
flake8-breakpoint
flake8-builtins==1.5.3
flake8-logging-format==0.6.0
flake8-nb==0.3.0
Expand Down