Skip to content

Commit

Permalink
fixed None
Browse files Browse the repository at this point in the history
  • Loading branch information
Marek Ozana committed Dec 10, 2024
1 parent 3711fd3 commit 98b759c
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 11 deletions.
47 changes: 36 additions & 11 deletions polars_bloomberg/plbbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def bql(self, expression: str) -> pl.DataFrame:
request = self._create_bql_request(expression)
responses = self._send_request(request)
data, schema = self._parse_bql_responses(responses)
return pl.DataFrame(data, schema=schema)
return pl.DataFrame(data, schema=schema, strict=True)

def _create_request(
self,
Expand Down Expand Up @@ -315,26 +315,51 @@ def _parse_bql_responses(self, responses: list[Any]):
data.setdefault(col_name, []).extend(values)
all_column_types.update(col_types)

# Map string types to Polars data types
schema = self._map_column_types_to_schema(all_column_types)
data = self._convert_dates_and_handle_nans(data, schema)

return data, schema

def _convert_dates_and_handle_nans(self, data, schema):
    """Normalise raw column values in place according to ``schema``.

    Parameters
    ----------
    data : dict
        Mapping of column name to a list of raw values. The lists of
        affected columns are replaced in place.
    schema : dict
        Mapping of column name to Polars data type.

    Returns
    -------
    dict
        The same ``data`` mapping that was passed in, where

        * ``pl.Date`` columns hold ``datetime.date`` objects parsed from
          ``"%Y-%m-%dT%H:%M:%SZ"`` strings; any non-string value
          (for example ``None``) becomes ``None``;
        * ``pl.Float64`` / ``pl.Int64`` columns have the string ``"NaN"``
          replaced with ``None``;
        * all other columns are left untouched.

    Raises
    ------
    ValueError
        If a string in a ``pl.Date`` column does not match the expected
        timestamp format.
    """
    fmt = "%Y-%m-%dT%H:%M:%SZ"
    # Only existing keys are re-assigned, so the dict size never changes
    # while it is being iterated.
    for col, values in data.items():
        dtype = schema.get(col)
        if dtype == pl.Date:
            # A value can be None (missing), so only strings are parsed.
            data[col] = [
                datetime.strptime(v, fmt).date() if isinstance(v, str) else None
                for v in values
            ]
        elif dtype in (pl.Float64, pl.Int64):
            # "NaN" arrives as a string; replace it with a proper null.
            data[col] = [None if x == "NaN" else x for x in values]
    return data

def _map_column_types_to_schema(
    self, all_column_types: dict[str, str]
) -> dict[str, pl.DataType]:
    """Map column types from string representation to Polars data types.

    Parameters
    ----------
    all_column_types : dict[str, str]
        A dictionary mapping column names to their string type
        representations (``"STRING"``, ``"DOUBLE"``, ``"INT"``, ``"DATE"``).

    Returns
    -------
    dict
        A dictionary mapping column names to Polars data types. Type names
        that are not recognised fall back to ``pl.Utf8``.
    """
    type_mapping = {
        "STRING": pl.Utf8,
        "DOUBLE": pl.Float64,
        "INT": pl.Int64,
        "DATE": pl.Date,
    }
    # Unknown type names default to Utf8.
    return {
        col_name: type_mapping.get(col_type, pl.Utf8)
        for col_name, col_type in all_column_types.items()
    }

# Convert date strings to date objects
fmt = "%Y-%m-%dT%H:%M:%SZ"
for col, values in data.items():
if schema.get(col) == pl.Date:
data[col] = [datetime.strptime(v, fmt).date() for v in values]

return data, schema

def _parse_bql_response_dict(self, results: dict[str, Any]):
"""Parse BQL response dictionary into a table format.
Expand Down
69 changes: 69 additions & 0 deletions tests/test_plbbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,72 @@ def test_send_request_with_response_error(self, bquery):
# Assertions
bquery.session.sendRequest.assert_called_with(mock_request)
bquery.session.nextEvent.assert_called_once_with(5000)


class TestSchemaMapping:
    """Checks for ``BQuery._map_column_types_to_schema``."""

    @pytest.fixture
    def bq(self):
        """Provide a BQuery instance to each test."""
        return BQuery()

    def test_map_all_known_types(self, bq: BQuery):
        """Each recognised type name maps to its Polars data type."""
        schema = bq._map_column_types_to_schema(
            {
                "name": "STRING",
                "price": "DOUBLE",
                "quantity": "INT",
                "transaction_date": "DATE",
            }
        )
        assert schema == {
            "name": pl.Utf8,
            "price": pl.Float64,
            "quantity": pl.Int64,
            "transaction_date": pl.Date,
        }

    def test_map_with_unknown_types(self, bq: BQuery):
        """An unrecognised type name ("BOOLEAN") maps to Utf8."""
        schema = bq._map_column_types_to_schema(
            {
                "name": "STRING",
                "price": "DOUBLE",
                "status": "BOOLEAN",  # not in the type mapping
                "transaction_date": "DATE",
            }
        )
        assert schema == {
            "name": pl.Utf8,
            "price": pl.Float64,
            "status": pl.Utf8,  # fallback
            "transaction_date": pl.Date,
        }

    def test_map_empty_input(self, bq: BQuery):
        """An empty type dictionary yields an empty schema."""
        schema = bq._map_column_types_to_schema({})
        assert schema == {}

    def test_map_mixed_types(self, bq: BQuery):
        """Known and unknown type names can be mixed in one call."""
        schema = bq._map_column_types_to_schema(
            {
                "name": "STRING",
                "price": "DOUBLE",
                "status": "BOOLEAN",
                "quantity": "INT",
                "transaction_date": "DATE",
                "extra_field": "UNKNOWN_TYPE",
            }
        )
        assert schema == {
            "name": pl.Utf8,
            "price": pl.Float64,
            "status": pl.Utf8,  # fallback
            "quantity": pl.Int64,
            "transaction_date": pl.Date,
            "extra_field": pl.Utf8,  # fallback
        }

0 comments on commit 98b759c

Please sign in to comment.