diff --git a/src/starfile/functions.py b/src/starfile/functions.py index 792f666..637268a 100644 --- a/src/starfile/functions.py +++ b/src/starfile/functions.py @@ -116,4 +116,4 @@ def to_string( quote_character=quote_character, quote_all_strings=quote_all_strings, ) - return ''.join(line + '\n' for line in writer.lines()) + return writer.to_string() diff --git a/src/starfile/writer.py b/src/starfile/writer.py index 5c5043d..39efd73 100644 --- a/src/starfile/writer.py +++ b/src/starfile/writer.py @@ -66,12 +66,14 @@ def lines(self) -> Generator[str, None, None]: yield '' for line in self.data_block_generator(): yield line + + def to_string(self) -> str: + return ''.join(line + '\n' for line in self.lines()) def write(self): if self.filename is None: raise ValueError('Cannot write nameless file!') - with open(self.filename, mode='w+') as f: - f.writelines(line + '\n' for line in self.lines()) + self.filename.write_text(self.to_string()) def data_block_generator(self) -> Generator[str, None, None]: for block_name, block in self.data_blocks.items(): @@ -176,7 +178,7 @@ def loop_block( float_format=float_format, na_rep=na_rep, quoting=csv.QUOTE_NONE - ).split('\n'): + ).splitlines(): yield line yield '' diff --git a/tests/test_parsing.py b/tests/test_parsing.py index bbe93fa..2ffbdb6 100644 --- a/tests/test_parsing.py +++ b/tests/test_parsing.py @@ -1,3 +1,4 @@ +from pathlib import Path import time import pandas as pd @@ -288,17 +289,15 @@ def test_parse_as_string(): assert df['rlnResolution'].dtype == 'object' -def test_parse_na(): - import tempfile +def test_parse_na(tmpdir): import starfile - parts = pd.DataFrame({"property1":np.arange(10), "property2": np.random.rand(10)}) + + parts = pd.DataFrame({"property1": np.arange(10), "property2": np.random.rand(10)}) parts["property2"].values[-1] *= np.nan data = { - "particles":parts + "particles": parts } - - with tempfile.NamedTemporaryFile(mode="w") as tmpfile: - starfile.write(data, tmpfile.name) - tmpfile.seek(0) - data = starfile.read(tmpfile.name) - assert data["property2"].dtype == "float64" + tmpfile = Path(tmpdir) / "temp.star" + starfile.write(data, tmpfile) + data = starfile.read(tmpfile) + assert data["property2"].dtype == "float64"