feature(learner): add learner module
1 parent 1ed4e31 · commit d6a5be1
Showing 34 changed files with 3,332 additions and 0 deletions.
Empty file.
@@ -0,0 +1,143 @@
import argparse
import warnings

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.exceptions import ConvergenceWarning
from sklearn.preprocessing import MaxAbsScaler
from sklearn.metrics import r2_score
# Assumed to be scikit-learn's GP kernels, matching their usage below;
# drop these if library.model re-exports its own versions.
from sklearn.gaussian_process.kernels import ConstantKernel as C, Matern, WhiteKernel

warnings.filterwarnings("ignore", category=ConvergenceWarning)

from library.utils import *
from library.model import *
from library.active_learning import *
def parse_arguments():
    parser = argparse.ArgumentParser(description="Script for active learning and model training.")

    # All arguments previously loaded from config.csv, with the CSV values as defaults
    parser.add_argument('--name_list', type=str, default='Yield1,Yield2,Yield3,Yield4,Yield5', help="Comma-separated list of target column names")
    parser.add_argument('--nb_rep', type=int, default=100, help="Number of repetitions")
    parser.add_argument('--flatten', type=str, choices=['true', 'false'], default='false', help="Whether to flatten data")
    parser.add_argument('--seed', type=int, default=85, help="Random seed")
    parser.add_argument('--nb_new_data_predict', type=int, default=3000, help="Number of new data points to predict")
    parser.add_argument('--nb_new_data', type=int, default=50, help="Number of new data points")
    parser.add_argument('--parameter_step', type=int, default=20, help="Parameter step")
    parser.add_argument('--test', type=int, default=1, help="Test flag")
    parser.add_argument('--n_group', type=int, default=15, help="Number of groups")
    parser.add_argument('--ks', type=int, default=20, help="ks parameter")
    parser.add_argument('--km', type=int, default=50, help="km parameter")
    parser.add_argument('--plot', type=str, choices=['true', 'false'], default='true', help="Whether to plot the results")
    parser.add_argument('--data_folder', type=str, default='data/top50', help="Folder containing data")
    parser.add_argument('--parameter_file', type=str, default='param.tsv', help="Parameter file path")
    parser.add_argument('--save_name', type=str, default='new_exp/plate3', help="Name for saving outputs")

    args = parser.parse_args()

    # Convert boolean-like strings to actual booleans
    # (defaults are lowercase so they match the declared choices)
    args.flatten = args.flatten.lower() == 'true'
    args.plot = args.plot.lower() == 'true'

    # Convert the comma-separated list to an actual Python list
    args.name_list = args.name_list.split(',')

    return args
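# Example invocation, a sketch only: the script name and paths below are
# placeholders, not taken from the commit.
#   python learner.py --data_folder data/top50 --parameter_file param.tsv \
#       --nb_new_data 50 --save_name new_exp/plate3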
def main():
    args = parse_arguments()

    data_folder = args.data_folder
    name_list = args.name_list
    parameter_file = args.parameter_file
    nb_rep = args.nb_rep
    flatten = args.flatten
    seed = args.seed
    nb_new_data_predict = args.nb_new_data_predict
    nb_new_data = args.nb_new_data
    parameter_step = args.parameter_step
    test = args.test
    save_name = args.save_name
    n_group = args.n_group
    ks = args.ks
    km = args.km
    plot = args.plot

    # Load the sampling space and the training data
    element_list, element_max, sampling_condition = import_parameter(parameter_file, parameter_step)

    data, size_list = import_data(data_folder, verbose=True)
    check_column_names(data, element_list)

    no_element = len(element_list)
    y = np.array(data[name_list])
    y_mean = np.nanmean(y, axis=1)
    y_std = np.nanstd(y, axis=1)
    X = data.iloc[:, 0:no_element]

    # Composite kernel: constant scale × Matern(ν=2.5) similarity, plus a white
    # term for observation noise; hyperparameters are refined during fitting.
    params = {
        'kernel': [
            C() * Matern(length_scale=10, nu=2.5) + WhiteKernel(noise_level=1e-3, noise_level_bounds=(1e-3, 1e1))
        ],
        # 'alpha': [0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5]
        'alpha': [0.05],
    }

    X_train, X_test, y_train, y_test = split_and_flatten(X, y, ratio=0, flatten=flatten)
    scaler = MaxAbsScaler()
    X_train_norm = scaler.fit_transform(X_train)
    model = BayesianModels(n_folds=10, model_type='gp', params=params)
    model.train(X_train_norm, y_train)
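    # The test block below re-fits the selected hyperparameters on repeated random
    # splits (ratio=0.2 in split_and_flatten) and collects the held-out R² scores,
    # so the histogram shows the spread of performance rather than a single split's luck.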
    if test:
        best_param = {'alpha': [model.best_params['alpha']], 'kernel': [model.best_params['kernel']]}
        res = []
        for i in range(nb_rep):
            X_train, X_test, y_train, y_test = split_and_flatten(X, y, ratio=0.2, flatten=flatten)

            scaler = MaxAbsScaler()
            X_train_norm = scaler.fit_transform(X_train)
            X_test_norm = scaler.transform(X_test)

            eva_model = BayesianModels(model_type='gp', params=best_param)
            eva_model.train(X_train_norm, y_train, verbose=False)
            y_pred, std_pred = eva_model.predict(X_test_norm)
            res.append(r2_score(y_test, y_pred))

        plt.hist(res, bins=20, color='orange')
        plt.title(f'Histogram of R2 for different testing subsets, median = {np.median(res):.2f}', size=12)
        plt.show()  # render the histogram before moving on

    X_new = sampling_without_repeat(sampling_condition, num_samples=nb_new_data_predict, existing_data=X_train, seed=seed)
    X_new_norm = scaler.transform(X_new)
    y_pred, std_pred = model.predict(X_new_norm)
    clusters = cluster(X_new_norm, n_group)
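    # Expected improvement for a Gaussian posterior, relative to the best observed y*:
    #   z = (mu - y*) / sigma;  EI = (mu - y*) * Phi(z) + sigma * phi(z)
    # (standard form; the exact convention lives in the library's expected_improvement)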
    ei = expected_improvement(y_pred, std_pred, max(y_train))
    print("For EI:")
    ei_top, y_ei, ratio_ei, ei_cluster = find_top_elements(X_new, y_pred, clusters, ei, km, return_ratio=True)
    ei_top_norm = scaler.transform(ei_top)

    if plot:
        plot_selected_point(y_pred, std_pred, y_ei, 'EI selected')

        size_list.append(nb_new_data)
        y_mean = np.append(y_mean, y_ei)
        plot_each_round(y_mean, size_list, predict=True)

        plot_train_test(X_train_norm, ei_top_norm, element_list)

        fig, axes = plt.subplots(1, 1, figsize=(10, 4))
        plot_heatmap(axes, ei_top_norm, y_ei, element_list, 'EI')
        plt.xlabel("Yield: left-low, right-high")
        plt.tight_layout()
        plt.show()

    X_ei = pd.DataFrame(ei_top, columns=element_list)
    name = save_name + '_ei' + str(km) + '.csv'
    X_ei.to_csv(name, index=False)


if __name__ == "__main__":
    main()
@@ -0,0 +1,222 @@
import os
import random
import argparse

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
def calculate_yield(data: pd.DataFrame, jove_plus_line: int, jove_minus_line: int) -> pd.DataFrame:
    # Convert 1-based spreadsheet line numbers to 0-based row indices
    # (minus 1 for the header row, minus 1 more for zero-based indexing)
    jove_plus_index = jove_plus_line - 2
    jove_minus_index = jove_minus_line - 2

    # Get the autofluorescence (Jove-) and reference (Jove+) values
    autofluorescence = data.iloc[jove_minus_index].filter(like='Fluorescence').mean()
    reference = data.iloc[jove_plus_index].filter(like='Fluorescence').mean()

    # Create a yield column for each fluorescence column
    for col in data.columns:
        if 'Fluorescence' in col:
            yield_col = col.replace('Fluorescence', 'Yield')
            data[yield_col] = (data[col] - autofluorescence) / (reference - autofluorescence)

    return data
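# Worked example (illustrative numbers only): with autofluorescence = 200 and
# reference = 1800, a raw fluorescence of 1500 maps to
# (1500 - 200) / (1800 - 200) = 0.8125, i.e. about 81% of the Jove+ control.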
def add_calibrated_yield(data: pd.DataFrame, a: float, b: float) -> pd.DataFrame:
    # Add a "Calibrated Yield" column for each "Yield" column
    for col in data.columns:
        if 'Yield' in col and 'Calibrated' not in col:
            calibrated_yield_col = col.replace('Yield', 'Calibrated Yield')
            data[calibrated_yield_col] = a * data[col] + b

    return data
def fit_regression_with_outlier_removal(y: np.ndarray, y_ref: np.ndarray, r2_limit: float) -> tuple:
    max_outliers = int(0.3 * len(y))  # At most 30% of data points may be dropped as outliers
    current_r2 = 0
    num_outliers_removed = 0

    # Track positions in the *original* arrays so the returned indices stay valid
    # after deletions (np.delete shifts the positions of later points).
    original_indices = np.arange(len(y))
    outlier_indices = []

    while current_r2 <= r2_limit and num_outliers_removed < max_outliers:
        print(f"Current R²: {current_r2:.2f}, r2_limit: {r2_limit:.2f}, Outliers removed: {num_outliers_removed}")
        # Add a constant term for OLS
        X = sm.add_constant(y)
        model = sm.OLS(y_ref, X).fit()
        current_r2 = model.rsquared
        if current_r2 > r2_limit:
            break  # fit is good enough; don't drop one more point than necessary

        # Calculate Cook's distance for every remaining point
        influence = model.get_influence()
        cooks_d = influence.cooks_distance[0]

        # Remove the point with the largest Cook's distance,
        # recording its index in the original arrays
        max_cooks_index = np.argmax(cooks_d)
        outlier_indices.append(original_indices[max_cooks_index])
        y = np.delete(y, max_cooks_index)
        y_ref = np.delete(y_ref, max_cooks_index)
        original_indices = np.delete(original_indices, max_cooks_index)
        num_outliers_removed += 1

    # Fit the final model on the retained points
    final_model = sm.OLS(y_ref, sm.add_constant(y)).fit()
    a, b = final_model.params[1], final_model.params[0]
    r2_value = final_model.rsquared

    return a, b, r2_value, outlier_indices
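# Cook's distance for point i (standard definition):
#   D_i = (e_i² / (p · s²)) · (h_ii / (1 - h_ii)²)
# where e_i is the residual, h_ii the leverage, p the number of parameters and
# s² the residual variance. Large D_i marks points whose removal most changes the fit.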
def select_control_points(data: pd.DataFrame, jove_plus_index: int, jove_minus_index: int, n: int) -> pd.DataFrame:
    # Find the index of the point with the highest mean yield
    max_yield_index = data.filter(like='Yield').mean(axis=1).idxmax()

    # Always keep Jove+, Jove-, and the highest-yield point
    control_indices = {jove_plus_index, jove_minus_index, max_yield_index}

    # Fill up with random points until there are n control points
    # (n - len(control_indices) rather than a fixed n - 3, in case the anchors coincide)
    remaining_indices = list(set(data.index) - control_indices)
    random_indices = random.sample(remaining_indices, n - len(control_indices))
    control_indices.update(random_indices)

    # Return the DataFrame with the selected control points
    return data.loc[list(control_indices)]
def plot_calibrated_points(y: np.ndarray, y_ref: np.ndarray, outlier_indices: list, a: float, b: float, r2_value: float, output_file: str, input_filename: str, ref_filename: str):
    # Plot the calibrated points in blue and removed outliers in red
    plt.figure(figsize=(10, 6))
    plt.scatter(y, y_ref, color='blue', label='Calibrated Points')
    if outlier_indices:
        plt.scatter(y[outlier_indices], y_ref[outlier_indices], color='red', label='Removed Outliers')
    # Plot the regression line y = ax + b
    x_vals = np.array([min(y), max(y)])
    y_vals = a * x_vals + b
    plt.plot(x_vals, y_vals, color='green', label=f'Regression Line: y = {a:.2f}x + {b:.2f}, R² = {r2_value:.2f}')
    # Axis labels carry the source filenames
    plt.xlabel(f'Calibrated Yield ({os.path.basename(input_filename)})')
    plt.ylabel(f'Reference Yield ({os.path.basename(ref_filename)})')
    plt.title('Calibrated Points with Outliers Removed and Regression Line')
    plt.legend()
    plt.savefig(output_file, format='png')
    plt.close()
def detect_component_columns(data: pd.DataFrame) -> list:
    # Component columns are those that appear before the first "Fluorescence" column
    component_columns = []
    for col in data.columns:
        if 'Fluorescence' in col:
            break
        component_columns.append(col)
    return component_columns
def find_matching_indices(input_data: pd.DataFrame, ref_data: pd.DataFrame, component_columns: list, rounding_precision: int = 2) -> tuple:
    # Round component values before matching
    input_combinations = input_data[component_columns].round(rounding_precision).apply(tuple, axis=1)
    ref_combinations = ref_data[component_columns].round(rounding_precision).apply(tuple, axis=1)

    # A set makes the membership test O(1) per row
    ref_combinations_set = set(ref_combinations)

    # Collect the row indices whose component combinations match
    matching_input_indices = []
    matching_ref_indices = []

    for i, combination in enumerate(input_combinations):
        if combination in ref_combinations_set:
            # Find the (first) corresponding index in the reference data
            ref_index = ref_combinations[ref_combinations == combination].index[0]
            matching_input_indices.append(i)
            matching_ref_indices.append(ref_index)

    return matching_input_indices, matching_ref_indices
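# For example, at rounding_precision=2 an input row with components (0.256, 1.004)
# and a reference row with (0.26, 1.00) both round to (0.26, 1.0) and are treated
# as the same mix. (The numbers are illustrative only.)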
def compute_average_yields(modified_data: pd.DataFrame, ref_data: pd.DataFrame, matching_input_indices: list, matching_ref_indices: list) -> tuple:
    # Average the yield columns row-wise for the matching component combinations
    avg_yield = modified_data.filter(like='Yield').iloc[matching_input_indices].mean(axis=1).values
    avg_yield_ref = ref_data.filter(like='Yield').iloc[matching_ref_indices].mean(axis=1).values
    return avg_yield, avg_yield_ref
def load_data(file_path: str) -> pd.DataFrame:
    # Load data based on the file extension
    if file_path.endswith('.xlsx'):
        return pd.read_excel(file_path, sheet_name=0)  # Read the first sheet
    elif file_path.endswith('.csv'):
        return pd.read_csv(file_path)
    else:
        raise ValueError("Unsupported file format. Please provide a .csv or .xlsx file.")


def save_data(data: pd.DataFrame, output_file: str):
    # Save data based on the file extension
    if output_file.endswith('.xlsx'):
        data.to_excel(output_file, index=False)
    elif output_file.endswith('.csv'):
        data.to_csv(output_file, index=False)
    else:
        raise ValueError("Unsupported file format. Please specify a .csv or .xlsx output file.")
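# Example invocation, a sketch only: the script and file names are placeholders,
# not taken from the commit.
#   python calibration.py --file plate3.xlsx --jove_plus 2 --jove_minus 3 \
#       --ref_file plate2.xlsx --output plate3_calibrated.csv --plot calibration.png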
if __name__ == "__main__":
    # Set up argument parsing
    parser = argparse.ArgumentParser(description='Calculate yield based on fluorescence data and optionally apply calibration.')
    parser.add_argument('--file', type=str, required=True, help='Path to the input file (.csv or .xlsx)')
    parser.add_argument('--jove_plus', type=int, required=True, help='Line number for Jove+ (1-based index)')
    parser.add_argument('--jove_minus', type=int, required=True, help='Line number for Jove- (1-based index)')
    parser.add_argument('--r2_limit', type=float, default=0.8, help='R-squared limit for the regression (default: 0.8)')
    parser.add_argument('--ref_file', type=str, help='Path to the reference input file (.csv or .xlsx)')
    parser.add_argument('--output', type=str, required=True, help='Output file name (.csv or .xlsx)')
    parser.add_argument('--plot', type=str, help='Output PNG file name for the plot of calibrated points')
    parser.add_argument('--num_control_points', type=int, default=10, help='Number of control points to select (default: 10)')

    args = parser.parse_args()
    # Load the data from the input file
    input_data = load_data(args.file)

    # Calculate the yield and get the modified DataFrame
    modified_data = calculate_yield(input_data, args.jove_plus, args.jove_minus)

    # Detect component columns
    component_columns = detect_component_columns(modified_data)

    # Check if a reference file is provided
    if args.ref_file:
        # Load the reference data
        ref_data = load_data(args.ref_file)

        # Find matching indices based on component combinations
        matching_input_indices, matching_ref_indices = find_matching_indices(modified_data, ref_data, component_columns)

        # Compute average yields for the matching component combinations
        avg_yield, avg_yield_ref = compute_average_yields(modified_data, ref_data, matching_input_indices, matching_ref_indices)

        # Fit the regression with outlier removal on the average yields
        a, b, r2_value, outlier_indices = fit_regression_with_outlier_removal(avg_yield, avg_yield_ref, args.r2_limit)

        # Display the regression coefficients and R² value in the terminal
        print(f"Regression Line: y = {a:.2f}x + {b:.2f}")
        print(f"R² Value: {r2_value:.2f}")

        # Add calibrated yield columns
        calibrated_data = add_calibrated_yield(modified_data, a, b)

        # Plot the calibrated points with outliers and regression line if requested
        if args.plot:
            plot_calibrated_points(avg_yield, avg_yield_ref, outlier_indices, a, b, r2_value, args.plot, args.file, args.ref_file)
    else:
        # If no reference file is provided, just use the uncalibrated yields
        calibrated_data = modified_data

    # Save the modified DataFrame to the specified output file
    save_data(calibrated_data, args.output)
    print(f"Calibrated yields saved in {args.output}")
    if args.plot:
        print(f"Plot saved as {args.plot}")

    # Select control points
    jove_plus_index = args.jove_plus - 2
    jove_minus_index = args.jove_minus - 2
    control_data = select_control_points(modified_data, jove_plus_index, jove_minus_index, args.num_control_points)

    # Save the new control points
    outf = os.path.splitext(args.output)[0] + '_control_points.csv'
    save_data(control_data, outf)
    print(f"New control points saved in {outf}")
@@ -0,0 +1,19 @@
param,value
data_folder,data/top50
name_list,"Yield1,Yield2,Yield3,Yield4,Yield5"
parameter_file,param.tsv
nb_rep,100
flatten,False
seed,85
nb_new_data_predict,3000
nb_new_data,50
parameter_step,20
strategy,ei
theta,10
save_name,new_exp/plate3
n_group,15
km,50
ks,20
plot,True
test,1
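These values are now baked into the learner script as argparse defaults. If one still wanted to seed the parser from this CSV, a minimal sketch (the helper name and path are assumptions, not part of the commit):

import csv

def load_config(path='config.csv'):
    # Map the two-column param/value CSV to a plain dict of strings;
    # the caller is responsible for casting ints and booleans.
    with open(path, newline='') as f:
        return {row['param']: row['value'] for row in csv.DictReader(f)}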