diff --git a/src/anemoi/inference/runner.py b/src/anemoi/inference/runner.py index 79cb75d..f8768e0 100644 --- a/src/anemoi/inference/runner.py +++ b/src/anemoi/inference/runner.py @@ -119,12 +119,33 @@ def run( LOGGER.info("Loading input: %d fields (lagged=%d)", len(input_fields), len(self.lagged)) - input_fields_numpy = input_fields.to_numpy(dtype=np.float32) + if start_datetime is None: + start_datetime = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime") + + num_fields_per_date = len(input_fields) // len(self.lagged) # assumed + + # Check valid_datetime of input data + # The subsequent reshape operation assumes that input_fields are chunkable by datetime + for i, lag in enumerate(self.lagged): + date = start_datetime + datetime.timedelta(hours=lag) + dates_found = set( + field.datetime() for field in input_fields[i * num_fields_per_date : (i + 1) * num_fields_per_date] + ) + # All chunks must have the same datetime that must agree with the lag + if dates_found != {date}: + raise RuntimeError( + "Inconsistent datetimes detected.\n" + f"Datetimes in data: {', '.join(d.isoformat() for d in dates_found)}.\n" + f"Expected datetime: {date.isoformat()} (for lag {lag})" + ) + + input_fields_numpy = input_fields.to_numpy(dtype=np.float32, reshape=False) + print(input_fields_numpy.shape) input_fields_numpy = input_fields_numpy.reshape( len(self.lagged), - len(input_fields) // len(self.lagged), + num_fields_per_date, number_of_grid_points, ) # nlags, nparams, ngrid @@ -223,10 +244,6 @@ def run( :, constant_data_from_retrieved_fields_mask ] - if start_datetime is None: - start_datetime_str = input_fields.order_by(valid_datetime="ascending")[-1].metadata("valid_datetime") - start_datetime = datetime.datetime.fromisoformat(start_datetime_str) - constants = forcing_and_constants( source=input_fields[:1], param=self.checkpoint.computed_constants,