Skip to content

Commit

Permalink
Refactor state transition logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ckaraneen committed Sep 19, 2020
1 parent d0358ee commit e5a8fdb
Showing 1 changed file with 64 additions and 97 deletions.
161 changes: 64 additions & 97 deletions pybpodapi/bpod/bpod_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,46 @@ def __initialize_input_command_handler(self):
self.stdin = NonBlockingStreamReader(
sys.stdin) if settings.PYBPOD_API_ACCEPT_STDIN else None

def __transition_to_new_state(self, sma, event_id, transition_matrix,
current_trial, state_change_indexes,
is_state_timer_matrix=False,
debug_message=None):
new_state_set = False

def set_sma_current_state(new_state):
previous_state = sma.current_state
if sma.use_255_back_signal and new_state == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = new_state
logger.debug(('Transition occured: '
f'state {previous_state} -> {sma.current_state}'))
if not math.isnan(sma.current_state):
if debug_message is not None:
logger.debug(debug_message)
current_trial.states.append(sma.current_state)
state_change_indexes.append(
len(current_trial.events_occurrences) - 1)
current_state = sma.current_state
if is_state_timer_matrix:
this_state_timer_state = transition_matrix[current_state]
is_event_id_tup = event_id == sma.hardware.channels.events_positions.Tup
if is_event_id_tup and this_state_timer_state != current_state:
set_sma_current_state(this_state_timer_state)
new_state_set = True
else:
for transition_event_code, transition_state in transition_matrix[
current_state]:
if transition_event_code == event_id:
set_sma_current_state(transition_state)
new_state_set = True
else:
logger.debug((f'Event {transition_event_code} required '
f'for transition: state '
f'{sma.current_state} -> '
f'{transition_state}'))
return new_state_set

def __process_opcode(self, sma, opcode, data, state_change_indexes):
"""
Process data from bpod board given an opcode
Expand Down Expand Up @@ -538,103 +578,30 @@ def __process_opcode(self, sma, opcode, data, state_change_indexes):
)
self.trial_timestamps.append(event_timestamp)

# input matrix
if not transition_event_found:
logger.debug("transition event not found")
logger.debug("Current state: %s", sma.current_state)
for transition in sma.input_matrix[sma.current_state]:
logger.debug("Transition: %s", transition)
if transition[0] == event_id:
if sma.use_255_back_signal and transition[1] == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = transition[1]

if not math.isnan(sma.current_state):
logger.debug("adding states input matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)

transition_event_found = True

# state timer matrix
if not transition_event_found:
this_state_timer_transition = sma.state_timer_matrix[sma.current_state]
if event_id == sma.hardware.channels.events_positions.Tup:
if not (this_state_timer_transition == sma.current_state):
if sma.use_255_back_signal and this_state_timer_transition == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = this_state_timer_transition

if not math.isnan(sma.current_state):
logger.debug("adding states state timer matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
transition_event_found = True

# global timers start matrix
if not transition_event_found:
for transition in sma.global_timers.start_matrix[sma.current_state]:
if transition[0] == event_id:
if sma.use_255_back_signal and transition[1] == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = transition[1]

if not math.isnan(sma.current_state):
logger.debug("adding states global timers start matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
transition_event_found = True

# global timers end matrix
if not transition_event_found:
for transition in sma.global_timers.end_matrix[sma.current_state]:
if transition[0] == event_id:

if sma.use_255_back_signal and transition[1] == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = transition[1]

if not math.isnan(sma.current_state):
logger.debug("adding states global timers end matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
transition_event_found = True

# global counters matrix
if not transition_event_found:
for transition in sma.global_counters.matrix[sma.current_state]:
if transition[0] == event_id:

if sma.use_255_back_signal and transition[1] == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = transition[1]

if not math.isnan(sma.current_state):
logger.debug("adding states global timers end matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
transition_event_found = True

# conditions matrix
if not transition_event_found:
for transition in sma.conditions.matrix[sma.current_state]:
if transition[0] == event_id:

if sma.use_255_back_signal and transition[1] == 255:
sma.current_state = current_trial.states[-2]
else:
sma.current_state = transition[1]

if not math.isnan(sma.current_state):
logger.debug("adding states global timers end matrix")
current_trial.states.append(sma.current_state)
state_change_indexes.append(len(current_trial.events_occurrences) - 1)
transition_event_found = True
logger.debug("Current state: %s", sma.current_state)
transition_matrices = {
'input': sma.input_matrix,
'state_timer': sma.state_timer_matrix,
'global_timers_start': sma.global_timers.start_matrix,
'global_timers_end': sma.global_timers.end_matrix,
'global_counters': sma.global_counters.matrix,
'conditions': sma.conditions.matrix,
}
for transition_matrix_name, transition_matrix in \
transition_matrices.items():
transition_event_found = \
self.__transition_to_new_state(
sma,
event_id,
transition_matrix,
current_trial,
state_change_indexes,
is_state_timer_matrix=(transition_matrix_name ==
'state_timer'),
debug_message="Adding {} matrix states".format(
transition_matrix_name))
if transition_event_found:
break

logger.debug("States indexes: %s", current_trial.states)
if self._emulator is not None:
Expand Down

0 comments on commit e5a8fdb

Please sign in to comment.