# Source code for phenotrex.transforms.resampling
#
# Created by Lukas Lüftinger on 05/02/2019.
#
from typing import List, Optional, Union

import numpy as np
from numpy.random import RandomState
from sklearn.utils import resample

from phenotrex.util.logging import get_logger
from phenotrex.structure.records import TrainingRecord
class TrainingRecordResampler:
    """
    Instantiates an object which can generate versions of a TrainingRecord
    resampled to defined completeness and contamination levels.
    Requires prior fitting with full List[TrainingRecord]
    to get sources of contamination for both classes.

    :param random_state: Randomness seed (int or None), or an existing RandomState
                         which is then used directly while resampling.
    :param verb: Toggle verbosity
    """

    def __init__(
        self,
        random_state: Optional[Union[int, RandomState]] = None,
        verb: bool = False
    ):
        self.logger = get_logger(initname=self.__class__.__name__, verb=verb)
        # A RandomState passed in is used as-is (shared with the caller, not copied);
        # any other value is used as the seed of a new RandomState.
        self.random_state = (
            random_state if isinstance(random_state, RandomState) else RandomState(random_state)
        )
        # Feature lists of all positive / negative records, filled in by fit().
        # Each becomes a 1-D object array holding one feature list per record.
        self.conta_source_pos = None
        self.conta_source_neg = None
        self.fitted = False

    def fit(self, records: List[TrainingRecord]) -> bool:
        """
        Fit TrainingRecordResampler on full TrainingRecord list
        to determine set of positive and negative features for contamination resampling.
        Fitting happens only once; subsequent calls log a warning and change nothing.

        :param records: the full List[TrainingRecord] on which ml training will commence.
        :return: True if fitting was performed, else False.
        :raises RuntimeError: if a record has a trait_sign other than 0 or 1.
        """
        if self.fitted:
            self.logger.warning("TrainingRecordSampler already fitted on full TrainingRecord data."
                                " Refusing to fit again.")
            return False
        total_neg_featureset = []
        total_pos_featureset = []
        for record in records:
            if record.trait_sign == 1:
                total_pos_featureset.append(record.features)
            elif record.trait_sign == 0:
                total_neg_featureset.append(record.features)
            else:
                raise RuntimeError("Unexpected record sign found. Aborting.")
        self.conta_source_pos = self._pack_featuresets(total_pos_featureset)
        self.conta_source_neg = self._pack_featuresets(total_neg_featureset)
        self.fitted = True
        return True

    @staticmethod
    def _pack_featuresets(featuresets: List[list]) -> np.ndarray:
        """
        Pack a list of feature lists into a 1-D object array with one entry per record.

        np.array() on feature lists of differing lengths is rejected by recent NumPy versions,
        and builds a 2-D array when all lengths happen to match; filling a pre-allocated
        object array element by element always yields exactly one row per record.
        """
        packed = np.empty(len(featuresets), dtype=object)
        for i, featureset in enumerate(featuresets):
            packed[i] = featureset
        return packed

    def _get_conta_pool(self, trait_sign: int) -> np.ndarray:
        """
        Return the fitted feature lists of the class opposite to `trait_sign`.

        :param trait_sign: sign of the record to be contaminated (1 or 0).
        :return: the negative pool for sign 1, the positive pool for sign 0.
        :raises RuntimeError: if trait_sign is neither 1 nor 0.
        """
        if trait_sign == 1:
            return self.conta_source_neg
        if trait_sign == 0:
            return self.conta_source_pos
        raise RuntimeError(f"Unexpected record sign found: {trait_sign}. Aborting.")

    def get_resampled(
        self,
        record: TrainingRecord,
        comple: float = 1.,
        conta: float = 0.
    ) -> TrainingRecord:
        """
        Resample a TrainingRecord to defined completeness and contamination levels.

        Completeness keeps floor(len(features) * comple) of the record's own features,
        drawn without replacement. Contamination adds floor(len(source) * conta) features,
        drawn without replacement from the feature list of one randomly chosen record of the
        opposite class. Comple=1, Conta=1 therefore roughly doubles the feature count
        (exactly, if the chosen source record has as many features as `record`).

        :param record: the input TrainingRecord
        :param comple: completeness of returned TrainingRecord features. Range: 0 - 1
        :param conta: contamination of returned TrainingRecord features. Range: 0 - 1
        :return: a resampled TrainingRecord.
        :raises RuntimeError: if not fitted, if comple/conta are out of range, if the record
                              sign is not 0/1, or if conta > 0 but no records of the opposite
                              class were seen during fit.
        """
        if not self.fitted:
            raise RuntimeError(
                "TrainingRecordResampler is not fitted on full TrainingRecord set. Aborting."
            )
        if not 0 <= comple <= 1 or not 0 <= conta <= 1:
            raise RuntimeError("Invalid comple/conta settings. Must be between 0 and 1.")

        # make incomplete
        features = record.features
        n_features_comple = int(np.floor(len(features) * comple))
        incomplete_features = resample(
            features, replace=False, n_samples=n_features_comple, random_state=self.random_state
        )
        self.logger.info(
            f"Reduced features of TrainingRecord {record.identifier} "
            f"from {len(features)} to {n_features_comple}."
        )

        # make contaminations
        conta_pool = self._get_conta_pool(record.trait_sign)
        if conta_pool.shape[0] == 0:
            # The opposite class may be absent, e.g. in very small folds after StratifiedKFold.
            if conta > 0:
                raise RuntimeError(
                    f"Cannot contaminate TrainingRecord {record.identifier}: no records of the "
                    f"opposite class were seen during fitting. Aborting."
                )
            conta_features = []
        else:
            # randint's upper bound is exclusive, so passing the pool size makes every
            # record (including the last one) eligible as contamination source.
            source_set_id = self.random_state.randint(0, conta_pool.shape[0])
            conta_source = list(conta_pool[source_set_id])
            # conta <= 1, so this never exceeds the source size; min() is kept as a guard.
            n_features_conta = min(len(conta_source), int(np.floor(len(conta_source) * conta)))
            conta_features = list(self.random_state.choice(
                a=conta_source, size=n_features_conta, replace=False
            ))
        self.logger.info(
            f"Enriched features of TrainingRecord {record.identifier} "
            f"with {len(conta_features)} features from "
            f"{'positive' if record.trait_sign == 0 else 'negative'} set."
        )

        # NOTE(review): group information is not carried over to the resampled record;
        # confirm this is intended for grouped cross-validation.
        new_record = TrainingRecord(
            identifier=record.identifier,
            trait_name=record.trait_name,
            trait_sign=record.trait_sign,
            feature_type=record.feature_type,
            # list() guards against array-typed features, where `+` would add element-wise.
            features=list(incomplete_features) + conta_features,
            group_name=None,
            group_id=None
        )
        return new_record