Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
Created using spr 1.3.5
  • Loading branch information
mtrofin committed Feb 12, 2025
2 parents 4fb62bd + 2451eb1 commit 454b79c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
6 changes: 3 additions & 3 deletions compiler_opt/rl/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def kill(self):
'type': 'float',
},
})
writer.write_nl()
writer.write_newline()

class MockInteractiveProcess(MockProcess):
"""Mock clang interactive process that writes the log."""
Expand All @@ -128,10 +128,10 @@ def poll(self):
example_writer.write_context_marker(f'context_{self._counter}')
example_writer.write_observation_marker(0)
example_writer.write_buff([self._counter], ctypes.c_int64)
example_writer.write_nl()
example_writer.write_newline()
example_writer.write_outcome_marker(0)
example_writer.write_buff([3.14], ctypes.c_float)
example_writer.write_nl()
example_writer.write_newline()
self._counter += 1
return None

Expand Down
38 changes: 21 additions & 17 deletions compiler_opt/rl/log_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def json_to_bytes(d) -> bytes:
class LogTestExampleBuilder:
"""Construct a log."""

nl = '\n'.encode('utf-8')
error_nl = 'hi there'.encode('utf-8')
newline = b'\n'
error_newline = b'hi there'

class ErrorMarkers(enum.IntEnum):
NONE = 0
Expand All @@ -58,28 +58,28 @@ def __init__(

def write_buff(self, buffer: list, ct):
# we should get the ctypes array to bytes for pytype to be happy.
if self._introduce_error_pos == \
LogTestExampleBuilder.ErrorMarkers.TENSOR_BUF_POS:
if self._introduce_error_pos == LogTestExampleBuilder.ErrorMarkers.TENSOR_BUF_POS: # pylint: disable=line-too-long
buffer = buffer[len(buffer) // 2:]
# pytype:disable=wrong-arg-types
self._opened_file.write((ct * len(buffer))(*buffer))
# pytype:enable=wrong-arg-types

def write_nl(self, position=None):
self._opened_file.write(LogTestExampleBuilder.error_nl if position == self
._introduce_error_pos else LogTestExampleBuilder.nl)
def write_newline(self, position=None):
self._opened_file.write(
LogTestExampleBuilder.error_newline if position ==
self._introduce_error_pos else LogTestExampleBuilder.newline)

def write_context_marker(self, name: str):
self._opened_file.write(json_to_bytes({'context': name}))
self.write_nl(LogTestExampleBuilder.ErrorMarkers.CTX_MARKER_POS)
self.write_newline(LogTestExampleBuilder.ErrorMarkers.CTX_MARKER_POS)

def write_observation_marker(self, obs_idx: int):
self._opened_file.write(json_to_bytes({'observation': obs_idx}))
self.write_nl(LogTestExampleBuilder.ErrorMarkers.OBS_MARKER_POS)
self.write_newline(LogTestExampleBuilder.ErrorMarkers.OBS_MARKER_POS)

def write_outcome_marker(self, obs_idx: int):
self._opened_file.write(json_to_bytes({'outcome': obs_idx}))
self.write_nl(LogTestExampleBuilder.ErrorMarkers.OUTCOME_MARKER_POS)
self.write_newline(LogTestExampleBuilder.ErrorMarkers.OUTCOME_MARKER_POS)

def write_header(self, json_header: dict):
self._opened_file.write(json_to_bytes(json_header))
Expand Down Expand Up @@ -116,7 +116,8 @@ def create_example(fname: str,
'type': 'float'
}
})
example_writer.write_nl(LogTestExampleBuilder.ErrorMarkers.AFTER_HEADER)
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.AFTER_HEADER)
for ctx_id in range(nr_contexts):
t0_val = [v + ctx_id * 10 for v in t0_val]
t1_val = [v + ctx_id * 10 for v in t1_val]
Expand All @@ -126,10 +127,12 @@ def write_example_obs(obs: int):
example_writer.write_observation_marker(obs)
example_writer.write_buff(t0_val, ctypes.c_float)
example_writer.write_buff(t1_val, ctypes.c_int64)
example_writer.write_nl(LogTestExampleBuilder.ErrorMarkers.TENSORS_POS)
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.TENSORS_POS)
example_writer.write_outcome_marker(obs)
example_writer.write_buff(s, ctypes.c_float)
example_writer.write_nl(LogTestExampleBuilder.ErrorMarkers.OUTCOME_POS)
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.OUTCOME_POS)

write_example_obs(0)
t0_val = [v + 1 for v in t0_val]
Expand Down Expand Up @@ -277,9 +280,10 @@ def test_seq_example_conversion(self):

def test_errors(self):
logfile = self.create_tempfile()
for i in range(1, len(LogTestExampleBuilder.ErrorMarkers)):
create_example(
logfile, introduce_errors_pos=LogTestExampleBuilder.ErrorMarkers(i))
for error_marker in LogTestExampleBuilder.ErrorMarkers:
if not error_marker:
continue
create_example(logfile, introduce_errors_pos=error_marker)
with self.assertRaises(Exception):
log_reader.read_log_as_sequence_examples(logfile)

Expand All @@ -301,7 +305,7 @@ def test_truncated_tensors(self):
'type': 'float'
}
})
writer.write_nl()
writer.write_newline()
writer.write_context_marker('whatever')
writer.write_observation_marker(0)
writer.write_buff([1], ctypes.c_int16)
Expand Down

0 comments on commit 454b79c

Please sign in to comment.