-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[py] add log tools (gz_log & split log)
- Loading branch information
Showing
15 changed files
with
1,316 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# log tools & dataloader | ||
* store xxx.log.gz in `__log/` dir | ||
```bash | ||
python tools/logsplit.py __log/xxx.log.gz # generate split log | ||
python tools/logread.py __log/xxx.log/ # check log | ||
python data/tracker_vision.py __log/ # load all log in __log dir | ||
``` |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from torch.utils.data import DataLoader | ||
|
||
def data_loader(args, path): | ||
dset = TrackerVisionDataset( | ||
path, | ||
obs_len=args.obs_len, | ||
pred_len=args.pred_len, | ||
skip=args.skip, | ||
delim=args.delim) | ||
|
||
loader = DataLoader( | ||
dset, | ||
batch_size=args.batch_size, | ||
shuffle=True, | ||
num_workers=args.loader_num_workers, | ||
collate_fn=seq_collate) | ||
return dset, loader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import sys, os, copy | ||
from torch.utils.data import Dataset | ||
import numpy as np | ||
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket | ||
from tzcp.ssl.vision.messages_robocup_ssl_detection_tracked_pb2 import TrackedFrame, TeamColor | ||
sys.path.append('../../..') | ||
from rocos.log.tools.logread import get_content, read_log, MSG_INDEX as I | ||
from rocos.log.tools.logdefine import TYPE | ||
|
||
LOG_MIN_LEN = 100 | ||
|
||
class TrackerVisionDataset(Dataset): | ||
def __init__(self, data_dir, obs_len = 8, pred_len = 12, skip = 5): | ||
super(TrackerVisionDataset, self).__init__() | ||
self.data_dir = data_dir | ||
self.obs_len = obs_len | ||
self.pred_len = pred_len | ||
self.skip = skip | ||
self.seq_len = self.obs_len + self.pred_len | ||
|
||
self.data = [] | ||
|
||
all_log_files = [] | ||
# get log files | ||
for root, dirs, files in os.walk(self.data_dir): | ||
for file in files: | ||
all_log_files.append(os.path.join(root, file)) | ||
|
||
|
||
for log_file in all_log_files: | ||
content = get_content(log_file) | ||
msgs, type = read_log(content) | ||
assert type == TYPE.SSL_VISION_TRACKER_2020, "Unknown message type" | ||
data_seqs = self.generate_seq(msgs, TrackerWrapperPacket) | ||
self.data.extend(data_seqs) | ||
|
||
print(f"Data Dir : {data_dir}, Total search {len(all_log_files)} log files, found {len(self.data)} sequences") | ||
|
||
def generate_seq(self, msgs, MsgType): | ||
if len(msgs) < LOG_MIN_LEN: | ||
return [] | ||
|
||
dataset_seqs = [] | ||
|
||
data_seq = [] | ||
for data in msgs: | ||
msg = MsgType() | ||
msg.ParseFromString(data[I.MSG]) | ||
data = self.parse_single_msg(msg) | ||
if data is None: | ||
print("parse error in msg, maybe ball not detected or robot number not correct. skip") | ||
break | ||
data_seq.append(data) | ||
|
||
for i in range(0, len(data_seq) - self.seq_len, self.skip): | ||
dataset_seqs.extend(self.generate_single_seq(data_seq[i:i+self.seq_len])) | ||
|
||
return dataset_seqs | ||
|
||
def parse_single_msg(self, msg: TrackerWrapperPacket): | ||
frame = msg.tracked_frame | ||
if len(frame.balls) == 0: | ||
return None | ||
ball = frame.balls[0] | ||
robot_blue = {} | ||
robot_yellow = {} | ||
|
||
for r in frame.robots: | ||
# robot_data | ||
rd = np.array([r.pos.x, r.pos.y, r.orientation, r.vel.x, r.vel.y, r.vel_angular, r.visibility]) | ||
if r.robot_id.team_color == TeamColor.TEAM_COLOR_BLUE: | ||
robot_blue[r.robot_id.id] = rd | ||
elif r.robot_id.team_color == TeamColor.TEAM_COLOR_YELLOW: | ||
robot_yellow[r.robot_id.id] = rd | ||
return { | ||
"ball": np.array([ball.pos.x, ball.pos.y, ball.pos.z, ball.vel.x, ball.vel.y, ball.vel.z, ball.visibility]), | ||
"blue": robot_blue, | ||
"yellow": robot_yellow | ||
} | ||
|
||
def generate_single_seq(self, seqs): | ||
if len(seqs) == 0: | ||
return [] | ||
blue_id = list(seqs[0]["blue"].keys()) | ||
yellow_id = list(seqs[0]["yellow"].keys()) | ||
checked_blue_id = copy.deepcopy(blue_id) | ||
checked_yellow_id = copy.deepcopy(yellow_id) | ||
for seq in seqs: | ||
checked_blue_id = list(set(checked_blue_id) & set(seq["blue"].keys())) | ||
checked_yellow_id = list(set(checked_yellow_id) & set(seq["yellow"].keys())) | ||
assert len(checked_blue_id) == len(blue_id) and len(checked_yellow_id) == len(yellow_id), "robot number not correct" | ||
|
||
_ball_seq = np.empty((0, seqs[0]["ball"].shape[0])) | ||
_blue_seq = np.empty((0, len(blue_id), seqs[0]["blue"][blue_id[0]].shape[0])) | ||
_yellow_seq = np.empty((0, len(yellow_id), seqs[0]["yellow"][yellow_id[0]].shape[0])) | ||
|
||
for seq in seqs: | ||
ball = seq["ball"] | ||
blue = np.array([seq["blue"][i] for i in blue_id]) | ||
yellow = np.array([seq["yellow"][i] for i in yellow_id]) | ||
_ball_seq = np.vstack((_ball_seq, ball)) | ||
_blue_seq = np.vstack((_blue_seq, blue[None, :, :])) | ||
_yellow_seq = np.vstack((_yellow_seq, yellow[None, :, :])) | ||
|
||
return [{ | ||
"ball": _ball_seq, | ||
"blue": _blue_seq, | ||
"yellow": _yellow_seq | ||
}] | ||
|
||
def __len__(self): | ||
return len(self.data) | ||
|
||
def __getitem__(self, idx): | ||
return self.data[idx] | ||
|
||
if __name__ == "__main__": | ||
data_dir = sys.argv[1] | ||
dataset = TrackerVisionDataset(data_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
__pycache__/ | ||
__temp__* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
class TYPE: | ||
BLANK = 0 | ||
UNKNOWN = 1 | ||
SSL_VISION_2010 = 2 | ||
SSL_REFBOX_2013 = 3 | ||
SSL_VISION_2014 = 4 | ||
SSL_VISION_TRACKER_2020 = 5 | ||
SSL_INDEX_2021 = 6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os, sys, struct, gzip | ||
|
||
sys.path.append('./proto_gen') | ||
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket | ||
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_pb2 import SSL_WrapperPacket | ||
sys.path.append('../../..') | ||
from rocos.log.tools.logdefine import TYPE | ||
def get_content(filename): | ||
content = None | ||
if filename.endswith('.gz'): | ||
with gzip.open(filename, 'rb') as f: | ||
content = f.read() | ||
else: | ||
with open(filename, 'rb') as f: | ||
content = f.read() | ||
return content | ||
|
||
class MSG_INDEX: | ||
TIMESTAMP = 0 | ||
TYPE = 1 | ||
SIZE = 2 | ||
MSG = 3 | ||
|
||
def read_log(content): | ||
if content[:12] != b'TZ_SPLIT_LOG': | ||
print('Not a valid log file') | ||
return None | ||
msgs = [] | ||
data = content[12:] | ||
data_index = 0 | ||
while data_index < len(data): | ||
timestamp, type, size = struct.unpack('>qii', data[data_index:data_index+16]) | ||
data_index += 16 | ||
msg = data[data_index:data_index+size] | ||
data_index += size | ||
msgs.append((timestamp, type, size, msg)) | ||
return msgs, type | ||
|
||
def check_log(filename): | ||
content = get_content(filename) | ||
msgs, type = read_log(content) | ||
|
||
MsgType = None | ||
if type == TYPE.SSL_VISION_TRACKER_2020: | ||
MsgType = TrackerWrapperPacket | ||
elif type == TYPE.SSL_VISION_2014: | ||
MsgType = SSL_WrapperPacket | ||
else: | ||
print('Unknown message type') | ||
sys.exit(1) | ||
for msg in msgs: | ||
pack = MsgType() | ||
pack.ParseFromString(msg[3]) | ||
|
||
if type == TYPE.SSL_VISION_TRACKER_2020: | ||
frame_number = pack.tracked_frame.frame_number | ||
ball_size = pack.tracked_frame.balls.__len__() | ||
robot_size = pack.tracked_frame.robots.__len__() | ||
if ball_size != 1 or robot_size != 22: | ||
print(filename , " - frame_num: ", frame_number, 'ball size:', ball_size, 'robot size:', robot_size) | ||
# delete file | ||
os.remove(filename) | ||
return | ||
print(filename, ' - OK') | ||
if __name__ == '__main__': | ||
dir_name = sys.argv[1] | ||
all_log_files = [] | ||
# get all file in dir_name with walk | ||
for root, dirs, files in os.walk(dir_name): | ||
for file in files: | ||
all_log_files.append(os.path.join(root, file)) | ||
|
||
for log_file in all_log_files: | ||
check_log(log_file) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import os, sys, struct | ||
from typing import Optional | ||
import gzip | ||
import numpy as np | ||
|
||
# File Format | ||
# Each log file starts with the following header: | ||
|
||
# 1: String – File type (“SSL_LOG_FILE”) | ||
# 2: Int32 – Log file format version | ||
|
||
# Format version 1 encodes the protobuf messages in the following format: | ||
|
||
# 1: Int64 – Receiver timestamp in ns | ||
# 2: Int32 – Message type | ||
# 3: Int32 – Size of binary protobuf message | ||
# 4: String – Binary protobuf message | ||
|
||
# The message types are: | ||
|
||
# MESSAGE_BLANK = 0 (ignore message) | ||
# MESSAGE_UNKNOWN = 1 (try to guess message type by parsing the data) | ||
# MESSAGE_SSL_VISION_2010 = 2 | ||
# MESSAGE_SSL_REFBOX_2013 = 3 | ||
# MESSAGE_SSL_VISION_2014 = 4 | ||
# MESSAGE_SSL_VISION_TRACKER_2020 = 5 | ||
# MESSAGE_SSL_INDEX_2021 = 6 | ||
|
||
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_tracked_pb2 import TrackerWrapperPacket | ||
from tzcp.ssl.vision.messages_robocup_ssl_wrapper_pb2 import SSL_WrapperPacket | ||
from tzcp.ssl.ref.ssl_referee_pb2 import Referee | ||
sys.path.append('../../..') | ||
from rocos.log.tools.logdefine import TYPE | ||
class SplitLog: | ||
def __init__(self, name, compress=False): | ||
self.filename = name+".splitlog" + (".gz" if compress else "") | ||
# get path prefix | ||
path = os.path.dirname(self.filename) | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
self.file = gzip.open(self.filename, 'wb') if compress else open(self.filename, 'wb') | ||
self.file.write(b'TZ_SPLIT_LOG') | ||
def write(self,timestamp,type,size,data): | ||
self.file.write(struct.pack('>qii',timestamp,type,size)) | ||
self.file.write(data) | ||
def close(self): | ||
self.file.close() | ||
def __del__(self): | ||
self.close() | ||
class LogSplitter: | ||
class Config: | ||
def __init__(self): | ||
self.skip_not_running_stages = True | ||
self.tracker_source_name_filter = ['TIGERs'] | ||
self.record_ref_commands = [ | ||
Referee.Command.NORMAL_START, | ||
Referee.Command.FORCE_START, | ||
Referee.Command.DIRECT_FREE_YELLOW, | ||
Referee.Command.DIRECT_FREE_BLUE, | ||
] | ||
def __init__(self, config: Config = Config()): | ||
self.config = config | ||
self.counter = np.zeros(7, dtype=int) | ||
self.current_stage = None | ||
self.current_ref_command = None | ||
self.current_vision_tracker = None | ||
self.current_vision = None | ||
# self.split_vision = None | ||
self.split_tracker = None | ||
def new_log(self, filename): | ||
# if self.split_vision: | ||
# self.split_vision.close() | ||
if self.split_tracker: | ||
self.split_tracker.close() | ||
# self.split_vision = SplitLog(filename + '_vision') | ||
self.split_tracker = SplitLog(filename + '_tracker') | ||
def parse_msg(self, type, data, timestamp, size): | ||
msg = None | ||
if type == TYPE.SSL_REFBOX_2013: | ||
msg = Referee() | ||
msg.ParseFromString(data) | ||
self.current_ref_command = msg.command | ||
self.current_stage = msg.stage | ||
pass | ||
elif type == TYPE.SSL_VISION_2014: | ||
msg = SSL_WrapperPacket() | ||
msg.ParseFromString(data) | ||
self.current_vision = msg | ||
self.counter[type] += 1 | ||
# self.split_vision.write(timestamp, type, size, data) | ||
pass | ||
elif type == TYPE.SSL_VISION_TRACKER_2020: | ||
msg = TrackerWrapperPacket() | ||
msg.ParseFromString(data) | ||
if msg.source_name in self.config.tracker_source_name_filter: | ||
self.counter[type] += 1 | ||
self.current_vision_tracker = msg | ||
self.split_tracker.write(timestamp, type, size, data) | ||
pass | ||
else: | ||
# MESSAGE_BLANK | ||
# MESSAGE_UNKNOWN | ||
# SSL_VISION_2010 | ||
# SSL_INDEX_2021 | ||
# 'Unknown message type' | ||
return | ||
def split(self,filename,store_prefix=None): | ||
if store_prefix is None: | ||
store_prefix = os.path.splitext(filename)[0] | ||
if not os.path.exists(store_prefix): | ||
os.makedirs(store_prefix) | ||
with gzip.open(filename, 'rb') as f: | ||
content = f.read() | ||
# get header | ||
header, msgs = content[:16], content[16:] | ||
# check header | ||
if header[:12] != b'SSL_LOG_FILE': | ||
print('Not a valid log file') | ||
return | ||
version = struct.unpack('>I', header[12:])[0] | ||
# read messages | ||
msg_index = 0 | ||
while msg_index < len(msgs): | ||
timestamp, msg_type, msg_size = struct.unpack('>qii', msgs[msg_index:msg_index+16]) | ||
msg_index += 16 | ||
if msg_type == 0: | ||
continue | ||
msg = msgs[msg_index:msg_index+msg_size] | ||
msg_index = msg_index + msg_size | ||
last_ref_command = self.current_ref_command | ||
if self.current_ref_command not in self.config.record_ref_commands and msg_type != TYPE.SSL_REFBOX_2013: | ||
continue | ||
self.parse_msg(msg_type, msg, timestamp, msg_size) | ||
if self.current_ref_command != last_ref_command and self.current_ref_command in self.config.record_ref_commands: | ||
if last_ref_command is not None: | ||
print(Referee.Command.Name(last_ref_command), '->', Referee.Command.Name(self.current_ref_command)) | ||
path_prefix = os.path.join(store_prefix, str(self.current_stage) + '-' + Referee.Stage.Name(self.current_stage) + '/' +str(timestamp) + '-' + Referee.Command.Name(self.current_ref_command)) | ||
self.new_log(path_prefix) | ||
|
||
if __name__ == '__main__': | ||
splitter = LogSplitter() | ||
# set name as arg[1] | ||
log_name = sys.argv[1] | ||
splitter.split(log_name) |
Oops, something went wrong.