diff --git a/qcportal/qcportal/record_models.py b/qcportal/qcportal/record_models.py index cd1de8070..bbfaa514c 100644 --- a/qcportal/qcportal/record_models.py +++ b/qcportal/qcportal/record_models.py @@ -108,38 +108,35 @@ class Config: output_type: OutputTypeEnum = Field(..., description="The type of output this is (stdout, error, etc)") compression_type: CompressionEnum = Field(CompressionEnum.none, description="Compression method (such as lzma)") + data_: Optional[bytes] = None - data_url_: Optional[str] = None - compressed_data_: Optional[bytes] = None - decompressed_data_: Optional[Any] = None - + _data_url: Optional[str] = PrivateAttr(None) _client: Any = PrivateAttr(None) def propagate_client(self, client, history_base_url): self._client = client - self.data_url_ = f"{history_base_url}/outputs/{self.output_type.value}/data" + self._data_url = f"{history_base_url}/outputs/{self.output_type.value}/data" def _fetch_raw_data(self): - if self.compressed_data_ is None and self.decompressed_data_ is None: - cdata, ctype = self._client.make_request( - "get", - self.data_url_, - Tuple[bytes, CompressionEnum], - ) + if self.data_ is not None: + return + + if self._client is None: + raise RuntimeError("No client to fetch output data from") - self.compression_type = ctype - self.compressed_data_ = cdata + cdata, ctype = self._client.make_request( + "get", + self._data_url, + Tuple[bytes, CompressionEnum], + ) + + assert self.compression_type == ctype + self.data_ = cdata @property def data(self) -> Any: self._fetch_raw_data() - - # Decompress, then remove compressed form - if self.decompressed_data_ is None: - self.decompressed_data_ = decompress(self.compressed_data_, self.compression_type) - self.compressed_data_ = None - - return self.decompressed_data_ + return decompress(self.data_, self.compression_type) class ComputeHistory(BaseModel): @@ -217,38 +214,35 @@ class Config: name: str = Field(..., description="Name of the file") compression_type: CompressionEnum = Field(..., description="Compression method (such as lzma)") + data_: Optional[bytes] = None - data_url_: Optional[str] = None - compressed_data_: Optional[bytes] = None - decompressed_data_: Optional[Any] = None - + _data_url: Optional[str] = PrivateAttr(None) _client: Any = PrivateAttr(None) def propagate_client(self, client, record_base_url): self._client = client - self.data_url_ = f"{record_base_url}/native_files/{self.name}/data" + self._data_url = f"{record_base_url}/native_files/{self.name}/data" def _fetch_raw_data(self): - if self.compressed_data_ is None and self.decompressed_data_ is None: - cdata, ctype = self._client.make_request( - "get", - self.data_url_, - Tuple[bytes, CompressionEnum], - ) + if self.data_ is not None: + return + + if self._client is None: + raise RuntimeError("No client to fetch native file data from") - assert self.compression_type == ctype - self.compressed_data_ = cdata + cdata, ctype = self._client.make_request( + "get", + self._data_url, + Tuple[bytes, CompressionEnum], + ) + + assert self.compression_type == ctype + self.data_ = cdata @property def data(self) -> Any: self._fetch_raw_data() - - # Decompress, then remove compressed form - if self.decompressed_data_ is None: - self.decompressed_data_ = decompress(self.compressed_data_, self.compression_type) - self.compressed_data_ = None - - return self.decompressed_data_ + return decompress(self.data_, self.compression_type) def save_file( self, directory: str, new_name: Optional[str] = None, keep_compressed: bool = False, overwrite: bool = False diff --git a/qcportal/qcportal/singlepoint/record_models.py b/qcportal/qcportal/singlepoint/record_models.py index b582e759a..319b5a82f 100644 --- a/qcportal/qcportal/singlepoint/record_models.py +++ b/qcportal/qcportal/singlepoint/record_models.py @@ -66,39 +66,36 @@ class Config: extra = Extra.forbid compression_type: CompressionEnum + data_: Optional[bytes] = None - data_url_: Optional[str] = None - compressed_data_: Optional[bytes] = None - decompressed_data_: Optional[WavefunctionProperties] = None - + _data_url: Optional[str] = PrivateAttr(None) _client: Any = PrivateAttr(None) def propagate_client(self, client, record_base_url): self._client = client - self.data_url_ = f"{record_base_url}/wavefunction/data" + self._data_url = f"{record_base_url}/wavefunction/data" def _fetch_raw_data(self): - if self.compressed_data_ is None and self.decompressed_data_ is None: - cdata, ctype = self._client.make_request( - "get", - self.data_url_, - Tuple[bytes, CompressionEnum], - ) + if self.data_ is not None: + return + + if self._client is None: + raise RuntimeError("No client to fetch wavefunction data from") + + cdata, ctype = self._client.make_request( + "get", + self._data_url, + Tuple[bytes, CompressionEnum], + ) - assert self.compression_type == ctype - self.compressed_data_ = cdata + assert self.compression_type == ctype + self.data_ = cdata @property def data(self) -> WavefunctionProperties: self._fetch_raw_data() - - # Decompress, then remove compressed form - if self.decompressed_data_ is None: - wfn_dict = decompress(self.compressed_data_, self.compression_type) - self.decompressed_data_ = WavefunctionProperties(**wfn_dict) - self.compressed_data_ = None - - return self.decompressed_data_ + wfn_dict = decompress(self.data_, self.compression_type) + return WavefunctionProperties(**wfn_dict) class SinglepointRecord(BaseRecord):