Skip to content

Commit

Permalink
fix datatype
Browse files Browse the repository at this point in the history
  • Loading branch information
GemmaTuron committed Nov 30, 2023
1 parent c370cd5 commit 94bf511
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
18 changes: 18 additions & 0 deletions ersilia/io/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def __meta_by_key(self, k):
def __cast_values(self, vals, dtypes, output_keys):
v = []
for v_, t_, k_ in zip(vals, dtypes, output_keys):
self.logger.debug(v_)
self.logger.debug(t_)
self.logger.debug(k_)
if t_ in self._array_types:
if v_ is None:
v_ = [None] * self.__array_shape(k_)
Expand Down Expand Up @@ -196,24 +199,31 @@ def __expand_output_keys(self, vals, output_keys):
self.logger.debug("Values: {0}".format(v))
m = self.__meta_by_key(ok)
if ok not in current_pure_dtype:
self.logger.debug("Getting pure dtype for {0}".format(ok))
t = self.__pure_dtype(ok)
self.logger.debug("This is the pure datatype: {0}".format(t))
if t is None:
t = self._guess_pure_dtype_if_absent(v)
self.logger.debug("Guessed absent pure datatype: {0}".format(t))
current_pure_dtype[ok] = t
else:
t = current_pure_dtype[ok]
self.logger.debug("Datatype: {0}".format(t))
if t in self._array_types:
self.logger.debug("Datatype has been matched: {0} over {1}".format(t, self._array_types))
assert m is not None
if v is not None:
if len(m) > len(v):
self.logger.debug("Metadata {0} is longer than values {1}".format(len(m), len(v)))
v = list(v) + [None] * (len(m) - len(v))
assert len(m) == len(v)
if merge_key:
self.logger.debug("Merge key is {0}".format(merge_key))
output_keys_expanded += [
"{0}{1}{2}".format(ok, FEATURE_MERGE_PATTERN, m_) for m_ in m
]
else:
self.logger.debug("No merge key")
output_keys_expanded += ["{0}".format(m_) for m_ in m]
else:
output_keys_expanded += [ok]
Expand Down Expand Up @@ -248,8 +258,16 @@ def _to_dataframe(self, result, model_id):
output_keys = [k for k in out.keys()]
vals = [out[k] for k in output_keys]
dtypes = [self.__pure_dtype(k) for k in output_keys]
are_dtypes_informative = False
for dtype in dtypes:
if dtype is not None:
are_dtypes_informative = True
if output_keys_expanded is None:
output_keys_expanded = self.__expand_output_keys(vals, output_keys)
if not are_dtypes_informative:
t = self._guess_pure_dtype_if_absent(vals)
if len(output_keys) == 1:
dtypes = [t]
vals = self.__cast_values(vals, dtypes, output_keys)
R += [[inp["key"], inp["input"]] + vals]
columns = ["key", "input"] + output_keys_expanded
Expand Down
3 changes: 3 additions & 0 deletions ersilia/io/pure.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def _is_numeric_array(self):
def _is_string_array(self):
if self._is_array():
data = np.array(self.data).ravel().tolist()
print(len(data))
data = [x for x in data if x is not None]
print(len(data))
if len(data) < 1:
return False
for x in data:
if not PureDataTyper(x)._is_string():
print(x, "HERE")
return False
return True
else:
Expand Down

0 comments on commit 94bf511

Please sign in to comment.