Skip to content

Commit

Permalink
support batch size 1
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesq34 committed Aug 8, 2017
1 parent ee75927 commit 4afd46d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models/pointnet_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
with tf.variable_scope('transform_net2') as sc:
transform = feature_transform_net(net, is_training, bn_decay, K=64)
end_points['transform'] = transform
net_transformed = tf.matmul(tf.squeeze(net), transform)
net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
net_transformed = tf.expand_dims(net_transformed, [2])

net = tf_util.conv2d(net_transformed, 64, [1,1],
Expand Down
2 changes: 1 addition & 1 deletion models/pointnet_seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
with tf.variable_scope('transform_net2') as sc:
transform = feature_transform_net(net, is_training, bn_decay, K=64)
end_points['transform'] = transform
net_transformed = tf.matmul(tf.squeeze(net), transform)
net_transformed = tf.matmul(tf.squeeze(net, axis=[2]), transform)
point_feat = tf.expand_dims(net_transformed, [2])
print(point_feat)

Expand Down

0 comments on commit 4afd46d

Please sign in to comment.