-
Notifications
You must be signed in to change notification settings - Fork 53
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
403 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
local runX = 0 | ||
local runY = 0 | ||
local runPos = function() | ||
return CGeoPoint(runX, runY) | ||
end | ||
local getRand = function() | ||
local r = 2*(math.random()-0.5) -- [-1,1] | ||
return r * 4000 | ||
end | ||
|
||
local getDataStr = function(v) | ||
local data = { | ||
os.clock(), | ||
v:X(),v:Y(),v:VelX(),v:VelY(), | ||
v:RawRotVel(),v:RotVel(), | ||
-- v:RawPos():x(),v:RawPos():y(), | ||
v:RawVel():x(),v:RawVel():y(), | ||
runX,runY, | ||
} | ||
local str = "" | ||
for i,value in pairs(data) do | ||
str = str .. string.format("%.3f;",value) | ||
end | ||
return str .. '\n' | ||
end | ||
|
||
local recFile = nil | ||
|
||
return { | ||
firstState = "reset", | ||
["reset"] = { | ||
switch = function() | ||
if recFile ~= nil then | ||
recFile:close() | ||
recFile = nil | ||
end | ||
runX = getRand() | ||
runY = getRand() | ||
return "randRun" | ||
end, | ||
Leader = task.stop(), | ||
match = "[L]" | ||
}, | ||
["randRun"] = { | ||
switch = function() | ||
if bufcnt(true, 100) then | ||
runX = getRand() | ||
runY = getRand() | ||
local fileName = "__data/robot_run/" .. os.date("%m%d%H%M%S") .. os.clock() | ||
recFile = io.open(fileName, 'w') | ||
recFile:write(getDataStr(player.instance("Leader"))) | ||
return "testData" | ||
end | ||
end, | ||
Leader = task.goCmuRush(runPos,0), | ||
match = "{L}" | ||
}, | ||
["testData"] = { | ||
switch = function() | ||
if player.toTargetDist("Leader") > 99999 then | ||
return "reset" | ||
end | ||
local data = getDataStr(player.instance("Leader")) | ||
debugEngine:gui_debug_msg(CGeoPoint(0,0),data) | ||
recFile:write(getDataStr(player.instance("Leader"))) | ||
if bufcnt(player.toTargetDist("Leader") < 10,10) then | ||
return "reset" | ||
end | ||
end, | ||
Leader = task.goCmuRush(runPos,0), | ||
match = "{L}" | ||
}, | ||
name = "ParamPredictTime", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
math.randomseed(os.time()) | ||
package.path = package.path .. ";./lua_scripts/?.lua" | ||
|
||
-- require("Judge") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
import numpy as np | ||
from rocos.utils.Geom import Vec | ||
|
||
# get all file with walk | ||
def get_all_files(path): | ||
all_files = [] | ||
for root, dirs, files in os.walk(path): | ||
for file in files: | ||
all_files.append(os.path.join(root, file)) | ||
return all_files | ||
|
||
def parse_file(filename): | ||
datas = [] | ||
with open(filename, 'r') as f: | ||
lines = f.readlines() | ||
for line in lines: | ||
data = line[:-2].split(';') | ||
data = [float(x) for x in data] | ||
datas.append(data) | ||
return datas | ||
|
||
|
||
# 0 : os.clock(), | ||
# 1 : v:X(), | ||
# 2 : v:Y(), | ||
# 3 : v:VelX(), | ||
# 4 : v:VelY(), | ||
# 5 : v:RawRotVel(), | ||
# 6 : v:RotVel(), | ||
# 7 : v:RawVel():x(), | ||
# 8 : v:RawVel():y(), | ||
# 9 : runX, | ||
# 10 : runY, | ||
|
||
class State: | ||
FRAME_RATE = 62.5 | ||
def __init__(self,data): | ||
self.time = data[0] | ||
self.pos = Vec(data[1], data[2]) | ||
self.vel = Vec(data[3], data[4]) | ||
self.rawVel = Vec(data[7], data[8]) | ||
self.target = Vec(data[9], data[10]) | ||
|
||
def parse_data(data, skip = 10): | ||
def get_train_data(startState, endState): | ||
time = endState.time - startState.time | ||
target = endState.pos - startState.pos | ||
vel = startState.vel | ||
rawVel = startState.rawVel | ||
# rotate | ||
rotAngle = target.dir | ||
target = target.rotate(-rotAngle) / 3000 | ||
vel = vel.rotate(-rotAngle) / 3000 | ||
rawVel = rawVel.rotate(-rotAngle) / 3000 | ||
# return [target[0], vel.x, vel.y, rawVel.x, rawVel.y], [time] | ||
# print("{}\t{}\t{}\t{}".format(target[0], rawVel.x, rawVel.y, time)) | ||
return [target[0], rawVel.x, rawVel.y], [time/5.0] | ||
frame = len(data) | ||
endData = data[-1] | ||
endData[0] = (frame-1)/State.FRAME_RATE | ||
endState = State(endData) | ||
|
||
trainDatas = [] | ||
res = [] | ||
for i in range(0, frame-1, skip): | ||
data[i][0] = i/State.FRAME_RATE | ||
state = State(data[i]) | ||
trainData, time = get_train_data(state, endState) | ||
trainDatas.append(trainData) | ||
res.append(time) | ||
return trainDatas, res | ||
|
||
import torch | ||
from torch.utils.data import Dataset | ||
class MoveTimeDataset(Dataset): | ||
def __init__(self, path, skipFrame = 10): | ||
self.data = [] | ||
self.output = [] | ||
files = get_all_files(path) | ||
for file in files: | ||
all_data = parse_file(file) | ||
data, output = parse_data(all_data, skipFrame) | ||
self.data.extend(data) | ||
self.output.extend(output) | ||
def __len__(self): | ||
return len(self.data) | ||
def __getitem__(self, idx): | ||
input = torch.tensor(self.data[idx], dtype=torch.float32) | ||
output = torch.tensor(self.output[idx], dtype=torch.float32) | ||
return input, output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import sys, os | ||
sys.path.append(os.path.dirname(__file__)) | ||
from MoveTimeDataset import MoveTimeDataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import numpy as np | ||
class Vec: | ||
def __init__(self, x, y): | ||
self.x = x | ||
self.y = y | ||
def __add__(self, other): | ||
return Vec(self.x + other.x, self.y + other.y) | ||
def __sub__(self, other): | ||
return Vec(self.x - other.x, self.y - other.y) | ||
def __mul__(self, other): | ||
return Vec(self.x * other, self.y * other) | ||
def __truediv__(self, other): | ||
return Vec(self.x / other, self.y / other) | ||
def __str__(self): | ||
return f"({self.x}, {self.y})" | ||
def __repr__(self): | ||
return f"({self.x}, {self.y})" | ||
def __iter__(self): | ||
return iter([self.x, self.y]) | ||
def __getitem__(self, index): | ||
return [self.x, self.y][index] | ||
@property | ||
def mod(self): | ||
return np.sqrt(self.x**2 + self.y**2) | ||
@mod.setter | ||
def mod(self, value): | ||
self.x = self.x / self.mod * value | ||
self.y = self.y / self.mod * value | ||
@property | ||
def dir(self): | ||
return np.arctan2(self.y, self.x) | ||
@dir.setter | ||
def dir(self, value): | ||
_mod = self.mod | ||
self.x = np.cos(value) * _mod | ||
self.y = np.sin(value) * _mod | ||
def _rotate(self, radians): | ||
x, y = self.x, self.y | ||
self.x = x * np.cos(radians) - y * np.sin(radians) | ||
self.y = x * np.sin(radians) + y * np.cos(radians) | ||
def rotate(self, radians): | ||
newVec = Vec(self.x, self.y) | ||
newVec._rotate(radians) | ||
return newVec | ||
|
||
if __name__ == "__main__": | ||
v = Vec(1,1) | ||
print(v) | ||
v._rotate(np.pi/2) | ||
print(v) | ||
v._rotate(-v.dir) | ||
print(v) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from rocos.dataset import MoveTimeDataset | ||
from torch.utils.data import DataLoader | ||
import torch | ||
from torch import nn | ||
from torch.utils.tensorboard import SummaryWriter | ||
writer = SummaryWriter() | ||
import matplotlib.pyplot as plt | ||
device = ( | ||
"cuda" | ||
if torch.cuda.is_available() | ||
else "mps" | ||
if torch.backends.mps.is_available() | ||
else "cpu" | ||
) | ||
print(f"Using {device} device") | ||
class MyModel(nn.Module): | ||
def __init__(self, n_feature, n_hidden, n_output, p=0.1): | ||
super().__init__() | ||
self.linear_relu_stack = nn.Sequential( | ||
nn.Linear(n_feature, n_hidden), | ||
nn.LeakyReLU(), | ||
nn.Dropout(p=p), | ||
nn.Linear(n_hidden, n_hidden), | ||
nn.LeakyReLU(), | ||
nn.Dropout(p=p), | ||
nn.Linear(n_hidden, n_output), | ||
) | ||
|
||
def forward(self, x): | ||
x = self.linear_relu_stack(x) | ||
return x | ||
|
||
EPOCH = 10000 | ||
|
||
if __name__ == '__main__': | ||
from datetime import datetime | ||
SYMBOL = str(datetime.now()) | ||
train_dataset = MoveTimeDataset('../__data/robot_run/train') | ||
test_dataset = MoveTimeDataset('../__data/robot_run/test', skipFrame=30) | ||
print("train dataset length : ", len(train_dataset)) | ||
print("test dataset length : ", len(test_dataset)) | ||
|
||
train_dataloader = DataLoader(train_dataset,batch_size=32, shuffle=True) | ||
test_dataloader = DataLoader(test_dataset,batch_size=32, shuffle=True) | ||
|
||
model = MyModel(3, 32, 1).to(device) | ||
print(model) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) | ||
loss_func = torch.nn.MSELoss() | ||
|
||
losses = [] | ||
for epoch in range(EPOCH): | ||
sum_loss = 0.0 | ||
for i, (x,y) in enumerate(train_dataloader): | ||
x, y = x.to(device), y.to(device) | ||
output = model(x) | ||
loss = loss_func(output, y) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
sum_loss += loss.data.cpu().numpy() | ||
sum_loss = sum_loss / len(train_dataset) | ||
writer.add_scalar("Loss/train", sum_loss, epoch) | ||
if epoch % 20 == 0: | ||
print(f"Train Loss : {sum_loss}") | ||
with torch.no_grad(): | ||
sum_loss = 0.0 | ||
for i, (x,y) in enumerate(test_dataloader): | ||
x, y = x.to(device), y.to(device) | ||
output = model(x) | ||
loss = loss_func(output, y) | ||
sum_loss += loss.data.cpu().numpy() | ||
sum_loss = sum_loss / len(test_dataset) | ||
writer.add_scalar("Loss/test", sum_loss, epoch) | ||
print(f"Test Loss : {sum_loss}") | ||
if epoch % 100 == 0: | ||
torch.save(model.state_dict(), "model/" + f'{SYMBOL}_model_{epoch//100}.pth') | ||
plt.plot(losses) | ||
plt.show() |
Oops, something went wrong.