Skip to content

Commit

Permalink
log_reader: tighten parsing (#438)
Browse files Browse the repository at this point in the history
The log format has a few places where we insert `\n` for human readability. They should be checked they are indeed just that character. 

Similarly, checking that the tensor data received matches in size what was expected.

Refactored a bit the test utility for constructing examples.
  • Loading branch information
mtrofin authored Feb 13, 2025
1 parent a8559c1 commit b8aaa54
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 93 deletions.
49 changes: 25 additions & 24 deletions compiler_opt/rl/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,23 @@ def kill(self):
with io.FileIO(fname + '.out', 'wb+') as f_out:
with io.FileIO(fname + '.in', 'rb+') as f_in:
del f_in
writer = log_reader_test.LogTestExampleBuilder(opened_file=f_out)
# Write the header describing the features/rewards
f_out.write(
log_reader_test.json_to_bytes({
'features': [{
'name': 'times_called',
'port': 0,
'shape': [1],
'type': 'int64_t',
},],
'score': {
'name': 'reward',
'port': 0,
'shape': [1],
'type': 'float',
},
}))
log_reader_test.write_nl(f_out)
writer.write_header({
'features': [{
'name': 'times_called',
'port': 0,
'shape': [1],
'type': 'int64_t',
},],
'score': {
'name': 'reward',
'port': 0,
'shape': [1],
'type': 'float',
},
})
writer.write_newline()

class MockInteractiveProcess(MockProcess):
"""Mock clang interactive process that writes the log."""
Expand All @@ -120,14 +120,15 @@ def poll(self):
if self._counter >= _NUM_STEPS:
f_out.close()
return None
log_reader_test.write_context_marker(f_out,
f'context_{self._counter}')
log_reader_test.write_observation_marker(f_out, 0)
log_reader_test.write_buff(f_out, [self._counter], ctypes.c_int64)
log_reader_test.write_nl(f_out)
log_reader_test.write_outcome_marker(f_out, 0)
log_reader_test.write_buff(f_out, [3.14], ctypes.c_float)
log_reader_test.write_nl(f_out)
example_writer = log_reader_test.LogTestExampleBuilder(
opened_file=f_out)
example_writer.write_context_marker(f'context_{self._counter}')
example_writer.write_observation_marker(0)
example_writer.write_buff([self._counter], ctypes.c_int64)
example_writer.write_newline()
example_writer.write_outcome_marker(0)
example_writer.write_buff([3.14], ctypes.c_float)
example_writer.write_newline()
self._counter += 1
return None

Expand Down
18 changes: 15 additions & 3 deletions compiler_opt/rl/log_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ class _Header:
def _read_tensor(fs: BinaryIO, ts: tf.TensorSpec) -> LogReaderTensorValue:
size = math.prod(ts.shape) * ctypes.sizeof(_dtype_to_ctype[ts.dtype])
data = fs.read(size)
if len(data) != size:
raise IOError(
f'Expected to read a total of {size} bytes for tensors, got {len(data)}'
)
return LogReaderTensorValue(ts, data)


Expand Down Expand Up @@ -175,20 +179,28 @@ def _enumerate_log_from_stream(
tensor_specs = header.features
score_spec = header.score
context = None

def expect_newline():
expected = f.readline().decode('utf-8')
if '\n' != expected:
raise IOError(f'Expected newline in log stream, got {expected}')

while event_str := f.readline():
event = json.loads(event_str)
if 'context' in event:
context = event['context']
continue
observation_id = int(event['observation'])
features = [_read_tensor(f, ts) for ts in tensor_specs]
f.readline()
expect_newline()
score = None
if score_spec is not None:
score_header = json.loads(f.readline())
assert int(score_header['outcome']) == observation_id
if int(score_header['outcome']) != observation_id:
raise IOError(f'Expected observation ID {observation_id} \
got {score_header["outcome"]}')
score = _read_tensor(f, score_spec)
f.readline()
expect_newline()
yield ObservationRecord(
context=context,
observation_id=observation_id,
Expand Down
198 changes: 132 additions & 66 deletions compiler_opt/rl/log_reader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for compiler_opt.rl.log_reader."""

import ctypes
import enum
import json
from compiler_opt.rl import log_reader

Expand All @@ -30,83 +31,113 @@ def json_to_bytes(d) -> bytes:
return json.dumps(d).encode('utf-8')


nl = '\n'.encode('utf-8')


def write_buff(f: BinaryIO, buffer: list, ct):
# we should get the ctypes array to bytes for pytype to be happy.
f.write((ct * len(buffer))(*buffer)) # pytype:disable=wrong-arg-types


def write_context_marker(f: BinaryIO, name: str):
f.write(json_to_bytes({'context': name}))
f.write(nl)


def write_observation_marker(f: BinaryIO, obs_idx: int):
f.write(json_to_bytes({'observation': obs_idx}))
f.write(nl)


def write_nl(f: BinaryIO):
f.write(nl)


def write_outcome_marker(f: BinaryIO, obs_idx: int):
f.write(json_to_bytes({'outcome': obs_idx}))
f.write(nl)


def create_example(fname: str, nr_contexts=1):
class LogTestExampleBuilder:
"""Construct a log."""

newline = b'\n'
error_newline = b'hi there'

class ErrorMarkers(enum.IntEnum):
NONE = 0
AFTER_HEADER = enum.auto()
CTX_MARKER_POS = enum.auto()
OBS_MARKER_POS = enum.auto()
OUTCOME_MARKER_POS = enum.auto()
TENSOR_BUF_POS = enum.auto()
TENSORS_POS = enum.auto()
OUTCOME_POS = enum.auto()

def __init__(
self,
*,
opened_file: BinaryIO,
introduce_error_pos: ErrorMarkers = ErrorMarkers.NONE,
):
self._opened_file = opened_file
self._introduce_error_pos = introduce_error_pos

def write_buff(self, buffer: list, ct):
# we should get the ctypes array to bytes for pytype to be happy.
if self._introduce_error_pos == self.ErrorMarkers.TENSOR_BUF_POS:
buffer = buffer[len(buffer) // 2:]
# pytype:disable=wrong-arg-types
self._opened_file.write((ct * len(buffer))(*buffer))
# pytype:enable=wrong-arg-types

def write_newline(self, position=None):
self._opened_file.write(self.error_newline if position ==
self._introduce_error_pos else self.newline)

def write_context_marker(self, name: str):
self._opened_file.write(json_to_bytes({'context': name}))
self.write_newline(self.ErrorMarkers.CTX_MARKER_POS)

def write_observation_marker(self, obs_idx: int):
self._opened_file.write(json_to_bytes({'observation': obs_idx}))
self.write_newline(self.ErrorMarkers.OBS_MARKER_POS)

def write_outcome_marker(self, obs_idx: int):
self._opened_file.write(json_to_bytes({'outcome': obs_idx}))
self.write_newline(self.ErrorMarkers.OUTCOME_MARKER_POS)

def write_header(self, json_header: dict):
self._opened_file.write(json_to_bytes(json_header))


def create_example(fname: str,
*,
nr_contexts=1,
introduce_errors_pos: LogTestExampleBuilder
.ErrorMarkers = LogTestExampleBuilder.ErrorMarkers.NONE):
t0_val = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
t1_val = [1, 2, 3]
s = [1.2]

with open(fname, 'wb') as f:
f.write(
json_to_bytes({
'features': [{
'name': 'tensor_name2',
'port': 0,
'shape': [2, 3],
'type': 'float',
}, {
'name': 'tensor_name1',
'port': 0,
'shape': [3, 1],
'type': 'int64_t',
}],
'score': {
'name': 'reward',
'port': 0,
'shape': [1],
'type': 'float'
}
}))
write_nl(f)
example_writer = LogTestExampleBuilder(
opened_file=f, introduce_error_pos=introduce_errors_pos)
example_writer.write_header({
'features': [{
'name': 'tensor_name2',
'port': 0,
'shape': [2, 3],
'type': 'float',
}, {
'name': 'tensor_name1',
'port': 0,
'shape': [3, 1],
'type': 'int64_t',
}],
'score': {
'name': 'reward',
'port': 0,
'shape': [1],
'type': 'float'
}
})
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.AFTER_HEADER)
for ctx_id in range(nr_contexts):
t0_val = [v + ctx_id * 10 for v in t0_val]
t1_val = [v + ctx_id * 10 for v in t1_val]
write_context_marker(f, f'context_nr_{ctx_id}')
write_observation_marker(f, 0)
write_buff(f, t0_val, ctypes.c_float)
write_buff(f, t1_val, ctypes.c_int64)
write_nl(f)
write_outcome_marker(f, 0)
write_buff(f, s, ctypes.c_float)
write_nl(f)

example_writer.write_context_marker(f'context_nr_{ctx_id}')

def write_example_obs(obs: int):
example_writer.write_observation_marker(obs)
example_writer.write_buff(t0_val, ctypes.c_float)
example_writer.write_buff(t1_val, ctypes.c_int64)
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.TENSORS_POS)
example_writer.write_outcome_marker(obs)
example_writer.write_buff(s, ctypes.c_float)
example_writer.write_newline(
LogTestExampleBuilder.ErrorMarkers.OUTCOME_POS)

write_example_obs(0)
t0_val = [v + 1 for v in t0_val]
t1_val = [v + 1 for v in t1_val]
s[0] += 1

write_observation_marker(f, 1)
write_buff(f, t0_val, ctypes.c_float)
write_buff(f, t1_val, ctypes.c_int64)
write_nl(f)
write_outcome_marker(f, 1)
write_buff(f, s, ctypes.c_float)
write_nl(f)
write_example_obs(1)


class LogReaderTest(tf.test.TestCase):
Expand Down Expand Up @@ -246,6 +277,41 @@ def test_seq_example_conversion(self):
""", tf.train.SequenceExample())
self.assertProtoEquals(expected_ctx_0, seq_examples['context_nr_0'])

def test_errors(self):
logfile = self.create_tempfile()
for error_marker in LogTestExampleBuilder.ErrorMarkers:
if not error_marker:
continue
create_example(logfile, introduce_errors_pos=error_marker)
with self.assertRaises(Exception):
log_reader.read_log_as_sequence_examples(logfile)

def test_truncated_tensors(self):
logfile = self.create_tempfile()
with open(logfile, 'wb') as f:
writer = LogTestExampleBuilder(opened_file=f)
writer.write_header({
'features': [{
'name': 'tensor_name',
'port': 0,
'shape': [2, 3],
'type': 'float',
}],
'score': {
'name': 'reward',
'port': 0,
'shape': [1],
'type': 'float'
}
})
writer.write_newline()
writer.write_context_marker('whatever')
writer.write_observation_marker(0)
writer.write_buff([1], ctypes.c_int16)

with self.assertRaises(Exception):
log_reader.read_log_as_sequence_examples(logfile)


if __name__ == '__main__':
tf.test.main()

0 comments on commit b8aaa54

Please sign in to comment.