
Commit

update vaes
DorinDaniil committed Nov 26, 2024
1 parent 1430504 commit 6175a1b
Showing 18 changed files with 460 additions and 821 deletions.
Six of the changed files could not be displayed (invalid or non-renderable content).
21 changes: 17 additions & 4 deletions demo/visualization.ipynb → demo/demo.ipynb

Large diffs are not rendered by default.

75 changes: 52 additions & 23 deletions demo/reinforce.py
@@ -9,56 +9,84 @@
import torch.optim as optim
from torch.distributions import Categorical


parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N',
                    help='random seed (default: 543)')
parser.add_argument('--render', action='store_true',
                    help='render the environment')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='interval between training status logs (default: 10)')
args = parser.parse_args()
def parse_arguments() -> argparse.Namespace:
    """
    Parse command line arguments.
    Returns:
        argparse.Namespace: Parsed command line arguments.
    """
    parser = argparse.ArgumentParser(description='PyTorch REINFORCE example')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                        help='discount factor (default: 0.99)')
    parser.add_argument('--seed', type=int, default=543, metavar='N',
                        help='random seed (default: 543)')
    parser.add_argument('--render', action='store_true',
                        help='render the environment')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='interval between training status logs (default: 10)')
    return parser.parse_args()

args = parse_arguments()

env = gym.make('Acrobot-v1')
env.reset(seed=args.seed)
torch.manual_seed(args.seed)


class Policy(nn.Module):
    """
    Policy network for the REINFORCE algorithm.
    """
    def __init__(self):
    def __init__(self) -> None:
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(6, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 3)

        self.saved_log_probs = []
        self.rewards = []
        self.saved_log_probs: list[torch.Tensor] = []
        self.rewards: list[float] = []

    def forward(self, x):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the policy network.
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: Action probabilities.
        """
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)


policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)
eps = np.finfo(np.float32).eps.item()

def select_action(state):
def select_action(state: np.ndarray) -> int:
    """
    Select an action based on the current state using the policy network.
    Args:
        state (np.ndarray): Current state.
    Returns:
        int: Selected action.
    """
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()


def finish_episode():
def finish_episode() -> None:
    """
    Finish the episode by updating the policy network.
    """
    R = 0
    policy_loss = []
    returns = deque()
@@ -76,8 +104,10 @@ def finish_episode():
    del policy.rewards[:]
    del policy.saved_log_probs[:]


def main():
def main() -> None:
    """
    Main function to run the REINFORCE algorithm.
    """
    running_reward = 10
    for i_episode in count(1):
        state, _ = env.reset()
@@ -102,6 +132,5 @@ def main():
                  "the last episode runs to {} time steps!".format(running_reward, t))
            break


if __name__ == '__main__':
    main()
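Note: the two collapsed hunks above hide most of finish_episode() and the episode loop in main(). For reference, the sketch below shows a standard REINFORCE update consistent with the variables visible in the diff (R, returns, policy_loss, eps, args.gamma, policy.rewards, policy.saved_log_probs). It is an illustrative reconstruction, not the file's verbatim code; the helper name reinforce_update and its signature are hypothetical.

from collections import deque

import torch


def reinforce_update(rewards, log_probs, optimizer, gamma, eps):
    # Hypothetical standalone sketch of the REINFORCE step; not code from this commit.
    # log_probs is expected to hold the shape-(1,) tensors stored in policy.saved_log_probs.
    R = 0.0
    returns = deque()
    # Accumulate discounted returns from the end of the episode backwards.
    for r in reversed(rewards):
        R = r + gamma * R
        returns.appendleft(R)
    returns = torch.tensor(list(returns))
    # Normalize returns to reduce the variance of the gradient estimate.
    returns = (returns - returns.mean()) / (returns.std() + eps)
    # Policy-gradient loss: negative log-probability of each action, weighted by its return.
    policy_loss = torch.cat([-log_prob * ret for log_prob, ret in zip(log_probs, returns)]).sum()
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()

In the script itself this logic would run with something like reinforce_update(policy.rewards, policy.saved_log_probs, optimizer, gamma=args.gamma, eps=eps), followed by clearing the two buffers as the visible del statements do.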
137 changes: 0 additions & 137 deletions demo/vae.py

This file was deleted.

