diff --git a/framework/data_generator/controls/__init__.py b/framework/data_generator/controls/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/framework/data_generator/controls/control_generator.py b/framework/data_generator/controls/control_generator.py new file mode 100644 index 0000000..b107107 --- /dev/null +++ b/framework/data_generator/controls/control_generator.py @@ -0,0 +1,113 @@ +import os +import sys +path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.insert(0, path) + +import numpy as np +from carball.controls.controls import ControlsCreator +from carball.json_parser.game import Game +from carball.json_parser.player import Player +import carball +from data_generator.base_generator import BaseDataGenerator +from rlbot.utils.logging_utils import get_logger +from typing import List + + +class ReplayControlGen(BaseDataGenerator): + """ + Generates Controler data from a set of local replays. + Use method get_data() for Iterable + """ + + def __init__(self, filepaths: List[str]): + """ + Parameters + ---------- + filepaths : List[str] + A list of paths to replay files, exclude the .replay file extensions + """ + self.files = filepaths + self.file = 0 + self.game = gameFromFile(self.files[self.file]) + self.maxFrame = len(self.game.frames.time) + self.currentFrame = 1 + self.player = 0 + + def initialize(self, **kwargs): + pass + + def has_next(self): + """Returns false when we are on the last player in the last file and have no more frames""" + if self.currentFrame >= self.maxFrame and self.player >= len( self.game.players ) - 1 and self.file >= len( self.files ) - 1: + return False + return True + + def _next(self) -> List[float]: + """ + :return: list, [ Throttle, Steer, pitch, yaw, roll, jump, boost, handbrake ] + """ + if self.currentFrame >= self.maxFrame: + self.player += 1 + self.currentFrame = 1 + if self.player >= len(self.game.players) and self.file < len(self.files): + self.file += 1 + self.game = gameFromFile(self.files[self.file]) + self.maxFrame = len(self.game.frames.time) + self.player = 0 + self.currentFrame = 1 + player = self.game.players[self.player] + c = getControls(player, self.currentFrame) + self.currentFrame += 1 + return c + + +def gameFromFile(replaypath: str) -> Game: + """ + Instantiats a game object from a replay and creates controls data. + Creates json file representaion of replay. + Parameters + ---------- + replaypath: str + Filepath to a replay, exclude .replay extension. + :return: Game + """ + _json = carball.decompile_replay(replaypath + '.replay', + replaypath + '.json', + overwrite=True) + + game = Game() + game.initialize(loaded_json=_json) + + ControlsCreator().get_controls(game) + return game + + +def getControls(player: Player, frame: int) -> List[float]: + """ + Gets the controls of a player at a frame + Parameters + ---------- + player : Player + game.players[p] + frame : int + Frame to get the controls from + :return: List[float] + List of format [ Throttle, Steer, pitch, yaw, roll, jump, boost, handbrake ] + """ + c = player.controls + throttle = c.throttle.get(frame) + steer = c.steer.get(frame) + pitch = c.pitch.get(frame) + yaw = c.yaw.get(frame) + roll = c.roll.get(frame) + jump = c.jump.get(frame) + boost = c.boost.get(frame) + handbrake = c.handbrake.get(frame) + controls = [throttle, steer, pitch, yaw, roll, jump, boost, handbrake] + for c in range(len(controls)): + control = controls[c] + if control == False or control == None or np.isnan(control): + controls[c] = 0 + elif control == True: + controls[c] = 1 + return controls