Skip to content

Commit

Permalink
Remove temporary caching of decompressed output/native files/wfn
Browse files Browse the repository at this point in the history
  • Loading branch information
bennybp committed Jan 13, 2024
1 parent f5f08b4 commit 89da752
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 61 deletions.
74 changes: 34 additions & 40 deletions qcportal/qcportal/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,38 +108,35 @@ class Config:

output_type: OutputTypeEnum = Field(..., description="The type of output this is (stdout, error, etc)")
compression_type: CompressionEnum = Field(CompressionEnum.none, description="Compression method (such as lzma)")
data_: Optional[bytes] = None

data_url_: Optional[str] = None
compressed_data_: Optional[bytes] = None
decompressed_data_: Optional[Any] = None

_data_url: Optional[str] = PrivateAttr(None)
_client: Any = PrivateAttr(None)

def propagate_client(self, client, history_base_url):
self._client = client
self.data_url_ = f"{history_base_url}/outputs/{self.output_type.value}/data"
self._data_url = f"{history_base_url}/outputs/{self.output_type.value}/data"

def _fetch_raw_data(self):
if self.compressed_data_ is None and self.decompressed_data_ is None:
cdata, ctype = self._client.make_request(
"get",
self.data_url_,
Tuple[bytes, CompressionEnum],
)
if self.data_ is not None:
return

if self._client is None:
raise RuntimeError("No client to fetch output data from")

self.compression_type = ctype
self.compressed_data_ = cdata
cdata, ctype = self._client.make_request(
"get",
self._data_url,
Tuple[bytes, CompressionEnum],
)

assert self.compression_type == ctype
self.data_ = cdata

@property
def data(self) -> Any:
self._fetch_raw_data()

# Decompress, then remove compressed form
if self.decompressed_data_ is None:
self.decompressed_data_ = decompress(self.compressed_data_, self.compression_type)
self.compressed_data_ = None

return self.decompressed_data_
return decompress(self.data_, self.compression_type)


class ComputeHistory(BaseModel):
Expand Down Expand Up @@ -217,38 +214,35 @@ class Config:

name: str = Field(..., description="Name of the file")
compression_type: CompressionEnum = Field(..., description="Compression method (such as lzma)")
data_: Optional[bytes] = None

data_url_: Optional[str] = None
compressed_data_: Optional[bytes] = None
decompressed_data_: Optional[Any] = None

_data_url: Optional[str] = PrivateAttr(None)
_client: Any = PrivateAttr(None)

def propagate_client(self, client, record_base_url):
self._client = client
self.data_url_ = f"{record_base_url}/native_files/{self.name}/data"
self._data_url = f"{record_base_url}/native_files/{self.name}/data"

def _fetch_raw_data(self):
if self.compressed_data_ is None and self.decompressed_data_ is None:
cdata, ctype = self._client.make_request(
"get",
self.data_url_,
Tuple[bytes, CompressionEnum],
)
if self.data_ is not None:
return

if self._client is None:
raise RuntimeError("No client to fetch native file data from")

assert self.compression_type == ctype
self.compressed_data_ = cdata
cdata, ctype = self._client.make_request(
"get",
self._data_url,
Tuple[bytes, CompressionEnum],
)

assert self.compression_type == ctype
self.data_ = cdata

@property
def data(self) -> Any:
self._fetch_raw_data()

# Decompress, then remove compressed form
if self.decompressed_data_ is None:
self.decompressed_data_ = decompress(self.compressed_data_, self.compression_type)
self.compressed_data_ = None

return self.decompressed_data_
return decompress(self.data_, self.compression_type)

def save_file(
self, directory: str, new_name: Optional[str] = None, keep_compressed: bool = False, overwrite: bool = False
Expand Down
39 changes: 18 additions & 21 deletions qcportal/qcportal/singlepoint/record_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,39 +66,36 @@ class Config:
extra = Extra.forbid

compression_type: CompressionEnum
data_: Optional[bytes] = None

data_url_: Optional[str] = None
compressed_data_: Optional[bytes] = None
decompressed_data_: Optional[WavefunctionProperties] = None

_data_url: Optional[str] = PrivateAttr(None)
_client: Any = PrivateAttr(None)

def propagate_client(self, client, record_base_url):
self._client = client
self.data_url_ = f"{record_base_url}/wavefunction/data"
self._data_url = f"{record_base_url}/wavefunction/data"

def _fetch_raw_data(self):
if self.compressed_data_ is None and self.decompressed_data_ is None:
cdata, ctype = self._client.make_request(
"get",
self.data_url_,
Tuple[bytes, CompressionEnum],
)
if self.data_ is not None:
return

if self._client is None:
raise RuntimeError("No client to fetch wavefunction data from")

cdata, ctype = self._client.make_request(
"get",
self._data_url,
Tuple[bytes, CompressionEnum],
)

assert self.compression_type == ctype
self.compressed_data_ = cdata
assert self.compression_type == ctype
self.data_ = cdata

@property
def data(self) -> WavefunctionProperties:
self._fetch_raw_data()

# Decompress, then remove compressed form
if self.decompressed_data_ is None:
wfn_dict = decompress(self.compressed_data_, self.compression_type)
self.decompressed_data_ = WavefunctionProperties(**wfn_dict)
self.compressed_data_ = None

return self.decompressed_data_
wfn_dict = decompress(self.data_, self.compression_type)
return WavefunctionProperties(**wfn_dict)


class SinglepointRecord(BaseRecord):
Expand Down

0 comments on commit 89da752

Please sign in to comment.