# knapsack_env_linear.py
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import torch  # used only by the commented-out torch variant of prepare_state
class MultiKnapsackEnv(gym.Env):
    def __init__(self, items, capacities):
        super().__init__()
        self.items = np.array(items, dtype=np.float32)  # items as (value, weight) pairs
        self.capacities = np.array(capacities, dtype=np.float32)  # knapsack capacities
        self.num_items = len(items)
        self.num_bags = len(capacities)

        # For normalization (disabled):
        # self.max_item_val = np.max(self.items[:, 0])
        # self.items[:, 0] /= self.max_item_val
        # self.max_item_weight = np.max(self.items[:, 1])
        # self.items[:, 1] /= self.max_item_weight
        # self.capacities /= self.max_item_weight

        # Action space: one discrete choice per (item, bag) pair,
        # encoded as action = item_idx + bag_idx * num_items.
        self.action_space = spaces.Discrete(self.num_items * self.num_bags)

        # Observation space: item values, item weights, remaining bag
        # capacities, and a binary selection status per item.
        self.observation_space = spaces.Dict({
            'item_values': spaces.Box(low=0, high=np.inf, shape=(self.num_items,), dtype=np.float32),
            'item_weights': spaces.Box(low=0, high=np.inf, shape=(self.num_items,), dtype=np.float32),
            'remaining_capacities': spaces.Box(low=0, high=np.max(capacities), shape=(self.num_bags,), dtype=np.float32),
            'selection_status': spaces.MultiBinary(self.num_items)
        })

        self.time_step = -1
        self.state = None
        self.reset()
    def prepare_state(self):
        # Flatten the dict observation into a single (1, obs_dim) float vector.
        item_values = np.array(self.state['item_values'], dtype=np.float32)
        item_weights = np.array(self.state['item_weights'], dtype=np.float32)
        remaining_capacities = np.array(self.state['remaining_capacities'], dtype=np.float32)
        selection_status = np.array(self.state['selection_status'], dtype=np.float32)
        ret = np.concatenate([item_values, item_weights, remaining_capacities, selection_status])
        return ret.reshape(1, len(ret))

        # Torch variant (disabled):
        # item_values = torch.as_tensor(self.state['item_values'], dtype=torch.float32)
        # item_weights = torch.as_tensor(self.state['item_weights'], dtype=torch.float32)
        # remaining_capacities = torch.as_tensor(self.state['remaining_capacities'], dtype=torch.float32)
        # selection_status = torch.as_tensor(self.state['selection_status'], dtype=torch.float32)
        # ret = torch.cat([item_values, item_weights, remaining_capacities, selection_status])
        # return ret.reshape(1, len(ret))
    def reset(self, seed=None, options=None):
        # Note: unlike the standard Gymnasium API, this returns only the
        # flattened observation (no info dict), matching the training loop.
        super().reset(seed=seed)
        self.state = {
            'item_values': self.items[:, 0],
            'item_weights': self.items[:, 1],
            'remaining_capacities': self.capacities.copy(),
            'selection_status': np.zeros(self.num_items, dtype=int)
        }
        self.time_step = -1
        return self.prepare_state()
    def step(self, action):
        # Decode the flat action into an (item, bag) pair.
        item_idx = action % self.num_items
        bag_idx = action // self.num_items
        item_value = self.items[item_idx, 0]
        item_weight = self.items[item_idx, 1]

        if self.state['selection_status'][item_idx] == 0 and self.state['remaining_capacities'][bag_idx] >= item_weight:
            self.state['remaining_capacities'][bag_idx] -= item_weight
            self.state['selection_status'][item_idx] = 1
            reward = float(item_value)  # TODO: how do we normalize?
        else:
            reward = 0.0  # no reward when the item cannot be added

        # Done when every item is selected, or when no remaining item fits in
        # any bag. Guard against np.min over an empty array when all items
        # are already selected.
        unselected = self.state['selection_status'] == 0
        if not np.any(unselected):
            done = True
        else:
            min_remaining_weight = np.min(self.items[unselected, 1])
            done = bool(np.all(self.state['remaining_capacities'] < min_remaining_weight))
        mask = self.valid_actions()

        # self.time_step += 1
        # if self.time_step > 100:
        #     done = True

        # Note: returns (obs, reward, done, mask) rather than the standard
        # Gymnasium 5-tuple; the action mask replaces the info dict.
        return self.prepare_state(), reward, done, mask
    def valid_actions(self):
        # Boolean mask over all (item, bag) actions: True where the item is
        # still unselected and fits in the bag's remaining capacity.
        mask = np.zeros(self.num_items * self.num_bags, dtype=bool)
        for i in range(self.num_items):
            for j in range(self.num_bags):
                if self.state['selection_status'][i] == 0 and self.state['remaining_capacities'][j] >= self.items[i][1]:
                    mask[i + j * self.num_items] = True
        return mask.reshape(1, len(mask))
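

# A minimal usage sketch (not part of the original file): roll out one
# episode with uniformly random valid actions. The items and capacities
# below are made-up example data, assumed only for illustration.
if __name__ == "__main__":
    items = [(10, 5), (7, 3), (4, 2), (9, 6)]  # (value, weight) pairs
    capacities = [8, 7]
    env = MultiKnapsackEnv(items, capacities)

    obs = env.reset()
    mask = env.valid_actions()
    total_reward, done = 0.0, False
    while not done and mask.any():
        # Sample uniformly among the currently valid (item, bag) actions.
        action = np.random.choice(np.flatnonzero(mask[0]))
        obs, reward, done, mask = env.step(action)
        total_reward += reward
    print(f"episode return: {total_reward}")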