Skip to content

Commit

Permalink
[py] add log tools (gz_log & split log)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark-tz committed Dec 2, 2024
1 parent 87091d8 commit 60ff84f
Show file tree
Hide file tree
Showing 15 changed files with 1,316 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ZBin/py_playground/rocos/log/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# log tools & dataloader
* Store the recorded `xxx.log.gz` in the `__log/` directory, then run:
```bash
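python tools/logsplit.py __log/xxx.log.gz # split into per-command logs under __log/xxx.log/
python tools/logread.py __log/xxx.log/ # check the split logs (deletes any log with a frame that lacks exactly 1 ball and 22 robots)
python data/tracker_vision.py __log/ # load all logs found under __log/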
```
Empty file.
Empty file.
17 changes: 17 additions & 0 deletions ZBin/py_playground/rocos/log/data/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from torch.utils.data import DataLoader

def data_loader(args, path):
    """Build the tracker-vision dataset stored at *path* and a DataLoader over it.

    Args:
        args: Namespace-like object providing ``obs_len``, ``pred_len``,
            ``skip``, ``batch_size`` and ``loader_num_workers``.
        path: Directory passed to ``TrackerVisionDataset`` as its data dir.

    Returns:
        A ``(dataset, loader)`` tuple; the loader shuffles on every epoch.
    """
    # NOTE(review): TrackerVisionDataset and seq_collate are not imported in
    # the visible part of this file -- confirm where they are expected to
    # come from.
    # ``delim`` is intentionally not forwarded: TrackerVisionDataset.__init__
    # only accepts (data_dir, obs_len, pred_len, skip), so passing it raised
    # a TypeError.
    dset = TrackerVisionDataset(
        path,
        obs_len=args.obs_len,
        pred_len=args.pred_len,
        skip=args.skip,
    )

    loader = DataLoader(
        dset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.loader_num_workers,
        collate_fn=seq_collate,
    )
    return dset, loader
119 changes: 119 additions & 0 deletions ZBin/py_playground/rocos/log/data/tracker_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import sys, os, copy
from torch.utils.data import Dataset
import numpy as np
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket
from tzcp.ssl.vision.messages_robocup_ssl_detection_tracked_pb2 import TrackedFrame, TeamColor
sys.path.append('../../..')
from rocos.log.tools.logread import get_content, read_log, MSG_INDEX as I
from rocos.log.tools.logdefine import TYPE

# Logs holding fewer messages than this are ignored by generate_seq().
LOG_MIN_LEN = 100


class TrackerVisionDataset(Dataset):
    """Fixed-length ball/robot trajectory windows cut from split tracker logs.

    Every file found below ``data_dir`` is read with ``get_content`` and
    ``read_log``; each message is parsed as a ``TrackerWrapperPacket`` and the
    resulting frame list is cut into windows of ``obs_len + pred_len`` frames
    whose start indices are ``skip`` frames apart.

    Each item is a dict of float arrays:
        "ball":   (seq_len, 7)
        "blue":   (seq_len, n_blue, 7)
        "yellow": (seq_len, n_yellow, 7)
    """

    # Number of values stored per robot by parse_single_msg().
    _ROBOT_FEATURES = 7

    def __init__(self, data_dir, obs_len=8, pred_len=12, skip=5):
        """Load every log below *data_dir* and pre-compute all windows.

        Args:
            data_dir: Directory that is walked recursively for log files.
            obs_len: Number of leading frames per window.
            pred_len: Number of trailing frames per window.
            skip: Distance, in frames, between two window starts.

        Raises:
            AssertionError: If a log holds messages that are not of type
                ``TYPE.SSL_VISION_TRACKER_2020``.
        """
        super().__init__()
        self.data_dir = data_dir
        self.obs_len = obs_len
        self.pred_len = pred_len
        self.skip = skip
        self.seq_len = self.obs_len + self.pred_len

        self.data = []

        all_log_files = []
        for root, _dirs, names in os.walk(self.data_dir):
            all_log_files.extend(os.path.join(root, name) for name in names)

        for log_file in all_log_files:
            parsed = read_log(get_content(log_file))
            if parsed is None:
                # read_log() returns None (after printing a notice) for files
                # without the split-log header, e.g. the raw xxx.log.gz that
                # sits next to the split logs. Skip instead of crashing.
                continue
            msgs, msg_type = parsed
            if not msgs:
                # Header-only log: nothing to learn from.
                continue
            if msg_type != TYPE.SSL_VISION_TRACKER_2020:
                # Raised explicitly (not via `assert`) so the check survives
                # `python -O`; the exception type is unchanged.
                raise AssertionError("Unknown message type")
            self.data.extend(self.generate_seq(msgs, TrackerWrapperPacket))

        print(f"Data Dir : {data_dir}, Total search {len(all_log_files)} log files, found {len(self.data)} sequences")

    def generate_seq(self, msgs, MsgType):
        """Parse *msgs* with *MsgType* and return all windows of the log.

        Args:
            msgs: Message tuples as produced by ``read_log``.
            MsgType: Protobuf message class used to decode each payload.

        Returns:
            List of window dicts (see class docstring); empty when the log
            has fewer than ``LOG_MIN_LEN`` messages.
        """
        if len(msgs) < LOG_MIN_LEN:
            return []

        frames = []
        for record in msgs:
            packet = MsgType()
            packet.ParseFromString(record[I.MSG])
            frame = self.parse_single_msg(packet)
            if frame is None:
                print("parse error in msg, maybe ball not detected or robot number not correct. skip")
                # Everything after the first unusable frame is dropped, so no
                # window ever spans a gap.
                break
            frames.append(frame)

        dataset_seqs = []
        for window in self._iter_windows(frames):
            dataset_seqs.extend(self.generate_single_seq(window))
        return dataset_seqs

    def _iter_windows(self, frames):
        """Yield every ``seq_len``-long slice of *frames*, ``skip`` apart."""
        # +1 so the window that ends exactly on the last frame is included
        # (the previous bound silently dropped it).
        last_start = len(frames) - self.seq_len
        for start in range(0, last_start + 1, self.skip):
            yield frames[start:start + self.seq_len]

    def parse_single_msg(self, msg: "TrackerWrapperPacket"):
        """Convert one tracker packet into plain numpy data.

        Returns:
            ``None`` when the frame contains no ball; otherwise a dict with
            "ball" -> [x, y, z, vx, vy, vz, visibility] of the first ball,
            and "blue" / "yellow" -> {robot id: [x, y, orientation, vx, vy,
            angular velocity, visibility]}. Robots whose team colour is
            neither blue nor yellow are ignored.
        """
        frame = msg.tracked_frame
        if len(frame.balls) == 0:
            return None
        ball = frame.balls[0]
        robot_blue = {}
        robot_yellow = {}

        for r in frame.robots:
            rd = np.array([r.pos.x, r.pos.y, r.orientation, r.vel.x, r.vel.y, r.vel_angular, r.visibility])
            if r.robot_id.team_color == TeamColor.TEAM_COLOR_BLUE:
                robot_blue[r.robot_id.id] = rd
            elif r.robot_id.team_color == TeamColor.TEAM_COLOR_YELLOW:
                robot_yellow[r.robot_id.id] = rd
        return {
            "ball": np.array([ball.pos.x, ball.pos.y, ball.pos.z, ball.vel.x, ball.vel.y, ball.vel.z, ball.visibility]),
            "blue": robot_blue,
            "yellow": robot_yellow
        }

    def generate_single_seq(self, seqs):
        """Stack one window of parsed frames into dense arrays.

        Args:
            seqs: Consecutive frame dicts as returned by parse_single_msg().

        Returns:
            ``[]`` for an empty window, otherwise a one-element list holding
            the window dict described in the class docstring. Robots are
            ordered as in the first frame of the window.

        Raises:
            AssertionError: If a robot present in the first frame is missing
                from a later frame of the window.
        """
        if len(seqs) == 0:
            return []
        blue_id = list(seqs[0]["blue"].keys())
        yellow_id = list(seqs[0]["yellow"].keys())

        for frame in seqs:
            blue_ok = all(i in frame["blue"] for i in blue_id)
            yellow_ok = all(i in frame["yellow"] for i in yellow_id)
            if not (blue_ok and yellow_ok):
                # Explicit raise keeps the original exception type while
                # surviving `python -O`.
                raise AssertionError("robot number not correct")

        # Build each array in one go instead of growing it with np.vstack on
        # every frame (which re-copied the whole array each time).
        ball_seq = np.array([frame["ball"] for frame in seqs], dtype=np.float64)
        return [{
            "ball": ball_seq,
            "blue": self._stack_team(seqs, "blue", blue_id),
            "yellow": self._stack_team(seqs, "yellow", yellow_id)
        }]

    def _stack_team(self, seqs, team, robot_ids):
        """Return a (frames, robots, features) array for one team."""
        if not robot_ids:
            # A team without robots used to raise IndexError; return an
            # empty array with the regular feature width instead.
            return np.empty((len(seqs), 0, self._ROBOT_FEATURES))
        return np.array(
            [[frame[team][i] for i in robot_ids] for frame in seqs],
            dtype=np.float64,
        )

    def __len__(self):
        """Return the number of pre-computed windows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the window dict stored at position *idx*."""
        return self.data[idx]

if __name__ == "__main__":
    # Script entry point: the directory to scan is the first CLI argument.
    # Constructing the dataset loads every log below it and prints a summary.
    data_dir = sys.argv[1]
    dataset = TrackerVisionDataset(data_dir=data_dir)
2 changes: 2 additions & 0 deletions ZBin/py_playground/rocos/log/tools/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__pycache__/
__temp__*
8 changes: 8 additions & 0 deletions ZBin/py_playground/rocos/log/tools/logdefine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
class TYPE:
    """Integer codes identifying the kind of message held by a log record."""

    (
        BLANK,
        UNKNOWN,
        SSL_VISION_2010,
        SSL_REFBOX_2013,
        SSL_VISION_2014,
        SSL_VISION_TRACKER_2020,
        SSL_INDEX_2021,
    ) = range(7)
74 changes: 74 additions & 0 deletions ZBin/py_playground/rocos/log/tools/logread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os, sys, struct, gzip

sys.path.append('./proto_gen')
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_pb2 import SSL_WrapperPacket
sys.path.append('../../..')
from rocos.log.tools.logdefine import TYPE
def get_content(filename):
    """Return the raw bytes of *filename*.

    Files whose name ends in ``.gz`` are decompressed with gzip on the fly;
    every other file is read as-is in binary mode.
    """
    opener = gzip.open if filename.endswith('.gz') else open
    with opener(filename, 'rb') as stream:
        return stream.read()

class MSG_INDEX:
    """Positions of the fields inside a (timestamp, type, size, msg) tuple."""

    TIMESTAMP, TYPE, SIZE, MSG = range(4)

def read_log(content):
    """Decode the records of a split log held in memory.

    The buffer must start with the 12-byte magic ``TZ_SPLIT_LOG`` followed by
    records made of a big-endian header (int64 timestamp, int32 type,
    int32 payload size) and the payload itself.

    Args:
        content: Complete file content as bytes.

    Returns:
        ``None`` (after printing a notice) when the magic is missing.
        Otherwise ``(msgs, msg_type)`` where ``msgs`` is a list of
        ``(timestamp, type, size, payload)`` tuples and ``msg_type`` is the
        type of the last complete record, or ``None`` when the log holds no
        record at all (this case used to raise UnboundLocalError).
    """
    magic = b'TZ_SPLIT_LOG'
    if content[:len(magic)] != magic:
        print('Not a valid log file')
        return None

    record_header = struct.Struct('>qii')
    total = len(content)
    msgs = []
    msg_type = None
    # Read in place with unpack_from instead of slicing the whole buffer.
    offset = len(magic)
    while offset < total:
        payload_start = offset + record_header.size
        if payload_start > total:
            print('Truncated log: incomplete record header, stop reading')
            break
        timestamp, record_type, size = record_header.unpack_from(content, offset)
        payload_end = payload_start + size
        if size < 0 or payload_end > total:
            # A cut-off payload would only produce an undecodable message.
            print('Truncated log: incomplete message payload, stop reading')
            break
        msgs.append((timestamp, record_type, size, content[payload_start:payload_end]))
        msg_type = record_type
        offset = payload_end
    return msgs, msg_type

def check_log(filename):
    """Validate one split log and delete it when a tracker frame is incomplete.

    For tracker logs every frame must contain exactly one ball and 22
    robots; the first frame violating this is reported and the file is
    removed. Logs of type ``SSL_VISION_2014`` are only parsed. Any other
    message type prints a notice and terminates the process with exit
    status 1.

    Args:
        filename: Path of the log file (plain or ``.gz``).
    """
    expected_balls = 1
    expected_robots = 22

    parsed = read_log(get_content(filename))
    if parsed is None:
        # read_log() already reported the missing header; unpacking its None
        # result used to crash with a TypeError.
        return
    msgs, msg_type = parsed
    if not msgs:
        print(filename, ' - empty log')
        return

    if msg_type == TYPE.SSL_VISION_TRACKER_2020:
        MsgType = TrackerWrapperPacket
    elif msg_type == TYPE.SSL_VISION_2014:
        MsgType = SSL_WrapperPacket
    else:
        print('Unknown message type')
        sys.exit(1)

    is_tracker = msg_type == TYPE.SSL_VISION_TRACKER_2020
    for msg in msgs:
        pack = MsgType()
        pack.ParseFromString(msg[MSG_INDEX.MSG])
        if not is_tracker:
            continue

        frame = pack.tracked_frame
        frame_number = frame.frame_number
        ball_size = len(frame.balls)
        robot_size = len(frame.robots)
        if ball_size != expected_balls or robot_size != expected_robots:
            print(filename , " - frame_num: ", frame_number, 'ball size:', ball_size, 'robot size:', robot_size)
            # The log is unusable for training: remove it from disk.
            os.remove(filename)
            return
    print(filename, ' - OK')
if __name__ == '__main__':
    dir_name = sys.argv[1]
    # Gather every file below dir_name up front; check_log() may delete
    # files, so the tree is not modified while it is being walked.
    all_log_files = [
        os.path.join(root, file)
        for root, dirs, files in os.walk(dir_name)
        for file in files
    ]

    for log_file in all_log_files:
        check_log(log_file)
144 changes: 144 additions & 0 deletions ZBin/py_playground/rocos/log/tools/logsplit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os, sys, struct
from typing import Optional
import gzip
import numpy as np

# File Format
# Each log file starts with the following header:

# 1: String – File type (“SSL_LOG_FILE”)
# 2: Int32 – Log file format version

# Format version 1 encodes the protobuf messages in the following format:

# 1: Int64 – Receiver timestamp in ns
# 2: Int32 – Message type
# 3: Int32 – Size of binary protobuf message
# 4: String – Binary protobuf message

# The message types are:

# MESSAGE_BLANK = 0 (ignore message)
# MESSAGE_UNKNOWN = 1 (try to guess message type by parsing the data)
# MESSAGE_SSL_VISION_2010 = 2
# MESSAGE_SSL_REFBOX_2013 = 3
# MESSAGE_SSL_VISION_2014 = 4
# MESSAGE_SSL_VISION_TRACKER_2020 = 5
# MESSAGE_SSL_INDEX_2021 = 6

from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_pb2 import SSL_WrapperPacket
from tzcp.ssl.ref.ssl_referee_pb2 import Referee
sys.path.append('../../..')
from rocos.log.tools.logdefine import TYPE
class SplitLog:
    """Writer for a single split-log file.

    The file starts with the magic ``TZ_SPLIT_LOG``; every record appended
    with write() is a big-endian header (int64 timestamp, int32 type,
    int32 size) followed by the raw payload.
    """

    def __init__(self, name, compress=False):
        """Create ``<name>.splitlog`` (plus ``.gz`` when *compress*) and write the magic.

        Missing parent directories are created.
        """
        # Set first so close()/__del__ stay safe if opening fails below.
        self.file = None
        self.filename = name + ".splitlog" + (".gz" if compress else "")
        parent = os.path.dirname(self.filename)
        # A bare file name has no parent; os.makedirs('') would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)
        opener = gzip.open if compress else open
        self.file = opener(self.filename, 'wb')
        self.file.write(b'TZ_SPLIT_LOG')

    def write(self, timestamp, type, size, data):
        """Append one record: packed header followed by *data*."""
        self.file.write(struct.pack('>qii', timestamp, type, size))
        self.file.write(data)

    def close(self):
        """Close the underlying file; safe to call more than once."""
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def __del__(self):
        self.close()
class LogSplitter:
    """Cut an SSL game log into one tracker split log per recorded referee command.

    Referee messages update the current stage/command. While the current
    command is one of ``config.record_ref_commands``, tracker packets whose
    ``source_name`` passes ``config.tracker_source_name_filter`` are written
    to the currently open SplitLog; a new SplitLog is started every time the
    command changes to a recorded one.
    """

    class Config:
        """Filters applied by LogSplitter."""

        def __init__(self):
            # NOTE(review): not consulted by any code visible in this file.
            self.skip_not_running_stages = True
            # Only tracker packets from these sources are kept.
            self.tracker_source_name_filter = ['TIGERs']
            # Referee commands during which messages are recorded.
            self.record_ref_commands = [
                Referee.Command.NORMAL_START,
                Referee.Command.FORCE_START,
                Referee.Command.DIRECT_FREE_YELLOW,
                Referee.Command.DIRECT_FREE_BLUE,
            ]

    def __init__(self, config: Optional[Config] = None):
        """Create a splitter.

        Args:
            config: Filter settings; a fresh ``Config()`` is created when
                omitted (a ``Config()`` default in the signature would be
                built once and shared by every instance).
        """
        self.config = self.Config() if config is None else config
        # Per-message-type counters, indexed by the TYPE codes (0..6).
        self.counter = np.zeros(7, dtype=int)
        self.current_stage = None
        self.current_ref_command = None
        self.current_vision_tracker = None
        self.current_vision = None
        self.split_tracker = None

    def new_log(self, filename):
        """Close the current tracker log, if any, and open ``<filename>_tracker``."""
        if self.split_tracker:
            self.split_tracker.close()
        self.split_tracker = SplitLog(filename + '_tracker')

    def parse_msg(self, type, data, timestamp, size):
        """Decode one message and update state / output according to its type.

        Args:
            type: Message type code (see TYPE).
            data: Raw protobuf payload.
            timestamp: Receiver timestamp copied into the split log.
            size: Payload size copied into the split log.
        """
        if type == TYPE.SSL_REFBOX_2013:
            msg = Referee()
            msg.ParseFromString(data)
            self.current_ref_command = msg.command
            self.current_stage = msg.stage
        elif type == TYPE.SSL_VISION_2014:
            msg = SSL_WrapperPacket()
            msg.ParseFromString(data)
            self.current_vision = msg
            self.counter[type] += 1
        elif type == TYPE.SSL_VISION_TRACKER_2020:
            msg = TrackerWrapperPacket()
            msg.ParseFromString(data)
            if msg.source_name in self.config.tracker_source_name_filter:
                self.counter[type] += 1
                self.current_vision_tracker = msg
                self.split_tracker.write(timestamp, type, size, data)
        # Every other type (BLANK, UNKNOWN, SSL_VISION_2010, SSL_INDEX_2021)
        # is ignored.

    def split(self, filename, store_prefix=None):
        """Split the gzip-compressed SSL log *filename*.

        Args:
            filename: Path of the ``.gz`` log starting with ``SSL_LOG_FILE``.
            store_prefix: Output directory; defaults to *filename* without
                its last extension.

        Output files are named
        ``<store_prefix>/<stage>-<stage name>/<timestamp>-<command name>_tracker.splitlog``.
        A file without the expected header is reported and skipped.
        """
        if store_prefix is None:
            store_prefix = os.path.splitext(filename)[0]
        os.makedirs(store_prefix, exist_ok=True)
        with gzip.open(filename, 'rb') as f:
            content = f.read()
        if content[:12] != b'SSL_LOG_FILE':
            print('Not a valid log file')
            return

        # Each input file is an independent recording: do not let the
        # referee state of a previously split file leak into this one.
        self.current_ref_command = None
        self.current_stage = None

        record_header = struct.Struct('>qii')
        total = len(content)
        # Records start after the 12-byte magic and the Int32 format version.
        offset = 16
        try:
            while offset < total:
                payload_start = offset + record_header.size
                if payload_start > total:
                    print('Truncated log file: incomplete record header, stop reading')
                    break
                timestamp, msg_type, msg_size = record_header.unpack_from(content, offset)
                payload_end = payload_start + msg_size
                if msg_size < 0 or payload_end > total:
                    print('Truncated log file: incomplete message payload, stop reading')
                    break
                # Always step over the payload, also for ignored records;
                # skipping only the header of a BLANK record left the reader
                # in the middle of its payload.
                offset = payload_end
                if msg_type == 0:  # MESSAGE_BLANK
                    continue

                last_ref_command = self.current_ref_command
                if last_ref_command not in self.config.record_ref_commands and msg_type != TYPE.SSL_REFBOX_2013:
                    continue
                self.parse_msg(msg_type, content[payload_start:payload_end], timestamp, msg_size)
                if self.current_ref_command != last_ref_command and self.current_ref_command in self.config.record_ref_commands:
                    if last_ref_command is not None:
                        print(Referee.Command.Name(last_ref_command), '->', Referee.Command.Name(self.current_ref_command))
                    stage_dir = str(self.current_stage) + '-' + Referee.Stage.Name(self.current_stage)
                    log_name = str(timestamp) + '-' + Referee.Command.Name(self.current_ref_command)
                    self.new_log(os.path.join(store_prefix, stage_dir, log_name))
        finally:
            # Flush and release the last output file deterministically
            # instead of relying on garbage collection.
            if self.split_tracker is not None:
                self.split_tracker.close()
                self.split_tracker = None

if __name__ == '__main__':
    # The gzip-compressed SSL log to split is the first command-line argument.
    log_name = sys.argv[1]
    splitter = LogSplitter()
    splitter.split(log_name)
Loading

0 comments on commit 60ff84f

Please sign in to comment.