-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_csv.py
229 lines (173 loc) · 8.23 KB
/
plot_csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import numpy as np
import re
import matplotlib.pyplot as plt
import os
from matplotlib import colors
import matplotlib.colors as mcolors
def get_csv_filepaths(directory, run_range=None):
    """Collect CSV file paths matching the ``calo<run>-HR.csv`` naming scheme.

    Walks *directory* recursively and keeps files whose run number falls
    inside *run_range* (inclusive on both ends).

    Args:
        directory (str): Directory to search for CSV files.
        run_range (tuple, optional): (min, max) run numbers to keep.
            Defaults to None, which keeps every matching file.

    Returns:
        list: Sorted list of matching CSV file paths.
    """
    name_pattern = re.compile(r'calo(\d{1,5})-HR\.csv')
    selected = []
    for root, _dirs, filenames in os.walk(directory):
        for name in filenames:
            if not name.endswith(".csv"):
                continue
            match = name_pattern.match(name)
            if match is None:
                continue
            run_number = int(match.group(1))  # run number encoded in the filename
            if run_range is None or run_range[0] <= run_number <= run_range[1]:
                print(f"Adding file: {name} (Run {run_number}) to the list")
                selected.append(os.path.join(root, name))
            else:
                print(f"File: {name} (Run {run_number}) outside range {run_range}")
    print(f"Total files found in range: {len(selected)}")
    return sorted(selected)  # deterministic order for downstream averaging
def process_multiple_csvs(filepaths):
    """Average the energy matrices built from several CSV files.

    Args:
        filepaths (list): Paths to the CSV files to read.

    Returns:
        array: Element-wise average over all successfully parsed matrices.

    Raises:
        ValueError: If *filepaths* is empty, or no file yields a valid matrix.
    """
    num_files = len(filepaths)
    if num_files == 0:
        raise ValueError("No CSV files found in the specified range.")
    accumulated = None  # lazily initialized so we size it from the first valid grid
    processed = 0       # how many files actually contributed a matrix
    for index, path in enumerate(filepaths, start=1):
        print(f"Processing file {index}/{num_files}: {path}")
        grid = transform_data_csv(path)
        if grid is None:
            print(f"Warning: Skipping file {path} because the matrix is None.")
            continue
        if accumulated is None:
            accumulated = np.zeros_like(grid)
        accumulated += grid
        processed += 1
    if processed == 0:
        raise ValueError("No valid CSV files were processed. Please check the files or the format.")
    # Average only over files that contributed, not over all inputs.
    return accumulated / processed
def downsample(grid, new_shape):
    """Downsample a 2D grid by averaging over equal-sized blocks.

    Args:
        grid (array): Original 2D array (e.g. 96x96).
        new_shape (tuple): Desired (rows, cols) after downsampling (e.g. 12x12).
            Each original dimension must be an integer multiple of the target.

    Returns:
        array: Block-averaged array of shape *new_shape*.

    Raises:
        ValueError: If the grid dimensions are not divisible by *new_shape*
            (previously this surfaced as a confusing reshape error).
    """
    rows, cols = grid.shape
    if rows % new_shape[0] or cols % new_shape[1]:
        raise ValueError(
            f"Grid shape {grid.shape} is not evenly divisible by target shape {new_shape}."
        )
    factor_x = rows // new_shape[0]
    factor_y = cols // new_shape[1]
    # Split the grid into (factor_x x factor_y) blocks and average each block.
    return grid.reshape(new_shape[0], factor_x, new_shape[1], factor_y).mean(axis=(1, 3))
def transform_data_csv(filepath, isDownsample = False, new_size = None):
    """Read a calorimeter CSV file and build a plottable 2D energy grid.

    Expects rows of the form ``x, y, <ignored>, energy``. The resulting grid
    can be plotted directly, e.g. with ``pyplot.imshow()``. Optionally
    downsamples the grid blockwise.

    Args:
        filepath (string): Path to the .csv file to be used.
        isDownsample (bool, optional): Whether downsampling should be applied. Defaults to False.
        new_size (tuple, optional): Target shape for blockwise downsampling. Defaults to None.

    Returns:
        array: 2D array containing the energy matrix, indexed as energy[x, y].
    """
    data = np.genfromtxt(filepath, delimiter=",")
    # A single-row file yields a 1D array from genfromtxt; the column
    # selection below would crash on it, so force a 2D view first.
    data = np.atleast_2d(data)
    data = data[:, [0, 1, 3]]  # keep x, y and the energy column
    x = data[:, 0].astype(int)
    y = data[:, 1].astype(int)
    energy = data[:, 2]
    # Size the grid from the largest coordinates seen; unset cells stay 0.
    energy_grid = np.zeros((x.max() + 1, y.max() + 1))
    # Vectorized scatter; like the element-wise loop, duplicates keep the last value.
    energy_grid[x, y] = energy
    if isDownsample:
        energy_grid = downsample(energy_grid, new_size)
    return energy_grid
def basic_plotter(energy_grid, savepath, clean=False):
    """Plot an energy grid with imshow on a logarithmic color scale and save it.

    With clean=True the image is written without axes, colorbar, or
    surrounding whitespace.

    Args:
        energy_grid (array): Array in a n x n shape.
        savepath (string): Output path (png, pdf, ...).
        clean (bool, optional): Strip axes, colorbar and borders. Defaults to False.
    """
    figure, axes = plt.subplots()
    image = axes.imshow(energy_grid, origin='lower', norm=colors.LogNorm())
    if clean:
        axes.axis('off')  # no ticks, labels or frame
        # Collapse the figure margins so no white border remains.
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    else:
        axes.set_xlabel("Detector x")
        axes.set_ylabel("Detector y")
        colorbar = figure.colorbar(image, ax=axes)
        colorbar.set_label("Energy in MeV")
    figure.savefig(savepath, bbox_inches='tight', pad_inches=0 if clean else 0.1)
    plt.close(figure)  # release the figure's memory
def max_6x6_submatrix_sum(matrix, size=6):
    """Find the contiguous square submatrix with the largest element sum.

    Scans every size x size window of *matrix* (default 6x6, the original
    hard-coded cluster size) and returns the window with the maximal sum.

    Args:
        matrix (array): 2D input array.
        size (int, optional): Edge length of the square window. Defaults to 6.

    Returns:
        array or None: View of the best submatrix, or None if *matrix* is
            smaller than size x size in either dimension.
    """
    n, m = matrix.shape
    best_sum = float('-inf')  # so any real window beats the initial value
    best_submatrix = None
    # range(n - size + 1) matches the original range(n - 5) for size == 6.
    for i in range(n - size + 1):
        for j in range(m - size + 1):
            window = matrix[i:i + size, j:j + size]
            window_sum = np.sum(window)
            if window_sum > best_sum:
                best_sum = window_sum
                best_submatrix = window
    return best_submatrix
def normalize_and_scale(matrix):
    """Normalize by the matrix maximum and apply power scaling with exponent 0.3.

    The exponent follows https://arxiv.org/abs/2308.09025. Input values are
    assumed non-negative (energies); a constant or all-zero matrix yields a
    zero matrix.

    Args:
        matrix (array): Input 2D array to normalize and scale.

    Returns:
        array: Normalized and power-scaled matrix.
    """
    matrix_min = np.min(matrix)
    matrix_max = np.max(matrix)
    # A constant matrix carries no structure. The extra matrix_max == 0 guard
    # prevents a division by zero (inf/nan) when max is 0 but min is negative,
    # which the original range check alone missed.
    if matrix_max - matrix_min == 0 or matrix_max == 0:
        return np.zeros_like(matrix)
    normalized_matrix = matrix / matrix_max
    power_scaled_matrix = np.power(normalized_matrix, 0.3)
    return power_scaled_matrix
def process_data(folder_path, run_range = None, plotting_save_path = None, clean_plot = False):
    """Build the averaged, cropped and scaled energy matrix for a data folder.

    Extracts data from *folder_path*, averages the per-run matrices, crops to
    the highest-energy 6x6 cluster and scales it according to
    https://arxiv.org/abs/2308.09025. Creates a plot if a save path is given.

    Args:
        folder_path (str): Path to the folder where the data is located.
        run_range (tuple, optional): Data indices to create matrices for. Defaults to None.
        plotting_save_path (str, optional): Path to save the plots as .png files to. Defaults to None.
        clean_plot (bool, optional): Creates plot without any labels or borders if set to True. Defaults to False.

    Returns:
        array: The final normalized and scaled matrix (previously discarded).
    """
    file_paths = get_csv_filepaths(folder_path, run_range)
    average_matrix = process_multiple_csvs(file_paths)
    best_submatrix = max_6x6_submatrix_sum(average_matrix)
    final_matrix = normalize_and_scale(best_submatrix)
    # 'is not None' (identity) instead of '!= None' (equality) per PEP 8.
    if plotting_save_path is not None:
        basic_plotter(final_matrix, plotting_save_path, clean_plot)
    # Return the result so callers can use it without re-reading the files;
    # existing callers that ignored the old None return are unaffected.
    return final_matrix