-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathrenderer_ogl.py
197 lines (153 loc) · 6.23 KB
/
renderer_ogl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from OpenGL import GL as gl
import util
import util_gau
import numpy as np
# Swap-interval control is provided by a WGL extension binding. When that
# binding cannot be imported, fall back to None so callers can test for it.
try:
    from OpenGL.raw.WGL.EXT.swap_control import wglSwapIntervalEXT
except Exception:
    # Best effort: any import-time failure disables vsync control, but
    # KeyboardInterrupt / SystemExit are no longer swallowed as they were
    # by the previous bare ``except:``.
    wglSwapIntervalEXT = None

# Module-level cache used by the GPU sort backends: the device copy of the
# gaussian positions and the id() of the gaussian object it was built from.
_sort_buffer_xyz = None
_sort_buffer_gausid = None  # used to tell whether gaussian is reloaded
def _sort_gaussian_cpu(gaus, view_mat):
    """Order gaussians by view-space depth using NumPy on the CPU.

    Args:
        gaus: Object exposing an ``xyz`` attribute convertible to an array
            of positions, one row of three coordinates per gaussian.
        view_mat: View matrix; only its upper-left 3x3 block and the first
            three entries of its fourth column are read (presumably a 4x4
            matrix -- confirm against the caller).

    Returns:
        An ``int32`` array of shape ``(N, 1)`` holding gaussian indices in
        ascending order of their view-space z coordinate.
    """
    positions = np.asarray(gaus.xyz)
    view = np.asarray(view_mat)
    rotation = view[None, :3, :3]
    translation = view[None, :3, 3, None]
    # Batched product: (1, 3, 3) @ (N, 3, 1) -> (N, 3, 1) column vectors.
    in_view = np.matmul(rotation, positions[..., None]) + translation
    order = np.argsort(in_view[:, 2, 0])
    return order.astype(np.int32).reshape(-1, 1)
def _sort_gaussian_cupy(gaus, view_mat):
    """Order gaussians by view-space depth using CuPy.

    The positions in ``gaus.xyz`` are converted to a CuPy array only when
    ``id(gaus)`` differs from the id recorded at the previous conversion;
    otherwise the cached array in ``_sort_buffer_xyz`` is reused.

    Args:
        gaus: Object exposing an ``xyz`` attribute with one row of three
            coordinates per gaussian.
        view_mat: View matrix; only its upper-left 3x3 block and the first
            three entries of its fourth column are read.

    Returns:
        A NumPy ``int32`` array of shape ``(N, 1)`` holding gaussian indices
        in ascending order of their view-space z coordinate.
    """
    import cupy as cp
    global _sort_buffer_gausid, _sort_buffer_xyz

    if id(gaus) != _sort_buffer_gausid:
        # A different gaussian object was passed in: rebuild the cached copy.
        _sort_buffer_xyz = cp.asarray(gaus.xyz)
        _sort_buffer_gausid = id(gaus)

    device_view = cp.asarray(view_mat)
    rotation = device_view[None, :3, :3]
    translation = device_view[None, :3, 3, None]
    in_view = rotation @ _sort_buffer_xyz[..., None] + translation
    order = cp.argsort(in_view[:, 2, 0]).astype(cp.int32).reshape(-1, 1)
    # Hand the result back as a NumPy array.
    return cp.asnumpy(order)
def _sort_gaussian_torch(gaus, view_mat):
    """Order gaussians by view-space depth using torch on a CUDA device.

    Relies on the module-level ``torch`` name bound by the backend-selection
    code. The positions in ``gaus.xyz`` are copied to the device only when
    ``id(gaus)`` differs from the id recorded at the previous upload;
    otherwise the cached tensor in ``_sort_buffer_xyz`` is reused.

    Args:
        gaus: Object exposing an ``xyz`` attribute with one row of three
            coordinates per gaussian.
        view_mat: View matrix; only its upper-left 3x3 block and the first
            three entries of its fourth column are read.

    Returns:
        A NumPy ``int32`` array of shape ``(N, 1)`` holding gaussian indices
        in ascending order of their view-space z coordinate.
    """
    global _sort_buffer_gausid, _sort_buffer_xyz

    if id(gaus) != _sort_buffer_gausid:
        # A different gaussian object was passed in: re-upload the positions.
        _sort_buffer_xyz = torch.tensor(gaus.xyz).cuda()
        _sort_buffer_gausid = id(gaus)

    device_view = torch.tensor(view_mat).cuda()
    rotation = device_view[None, :3, :3]
    translation = device_view[None, :3, 3, None]
    in_view = rotation @ _sort_buffer_xyz[..., None] + translation
    order = torch.argsort(in_view[:, 2, 0])
    # Convert to int32, make it a column, and bring it back to host memory.
    return order.type(torch.int32).reshape(-1, 1).cpu().numpy()
# Pick the sorting backend in order of preference:
# torch with CUDA available, then cupy, then the NumPy CPU implementation.
_sort_gaussian = None
try:
    import torch
    if torch.cuda.is_available():
        print("Detect torch cuda installed, will use torch as sorting backend")
        _sort_gaussian = _sort_gaussian_torch
except ImportError:
    # torch is not installed; fall through to the next candidate.
    pass

if _sort_gaussian is None:
    try:
        import cupy as cp
    except ImportError:
        _sort_gaussian = _sort_gaussian_cpu
    else:
        print("Detect cupy installed, will use cupy as sorting backend")
        _sort_gaussian = _sort_gaussian_cupy
class GaussianRenderBase:
    """Interface shared by gaussian renderer backends.

    Every rendering hook raises ``NotImplementedError`` here and is expected
    to be overridden. The base class itself only stores the current gaussian
    data reference and the ``reduce_updates`` flag.
    """

    def __init__(self):
        # Gaussian data currently held by the renderer (None until set).
        self.gaussians = None
        self._reduce_updates = True

    @property
    def reduce_updates(self):
        """Flag consulted by ``update_vsync``; ``True`` by default."""
        return self._reduce_updates

    @reduce_updates.setter
    def reduce_updates(self, val):
        self._reduce_updates = val
        # Re-apply the vsync setting every time the flag is assigned.
        self.update_vsync()

    def update_vsync(self):
        """Apply the ``reduce_updates`` flag; the base class only reports no support."""
        print("VSync is not supported")

    def update_gaussian_data(self, gaus: util_gau.GaussianData):
        """Load a new set of gaussians into the renderer."""
        raise NotImplementedError()

    def sort_and_update(self, camera: util.Camera = None):
        """Re-sort the gaussians for ``camera`` and refresh renderer state.

        ``camera`` is accepted so the signature agrees with overrides that
        take a camera. It defaults to ``None`` only so that existing
        zero-argument calls on the base class keep raising
        ``NotImplementedError`` rather than ``TypeError``.
        """
        raise NotImplementedError()

    def set_scale_modifier(self, modifier: float):
        """Set the scale modifier used when rendering."""
        raise NotImplementedError()

    def set_render_mod(self, mod: int):
        """Select the render mode."""
        raise NotImplementedError()

    def update_camera_pose(self, camera: util.Camera):
        """Push the pose of ``camera`` to the renderer."""
        raise NotImplementedError()

    def update_camera_intrin(self, camera: util.Camera):
        """Push the intrinsics of ``camera`` to the renderer."""
        raise NotImplementedError()

    def draw(self):
        """Render one frame."""
        raise NotImplementedError()

    def set_render_reso(self, w, h):
        """Set the render resolution to ``w`` x ``h``."""
        raise NotImplementedError()
class OpenGLRenderer(GaussianRenderBase):
    """Gaussian renderer that draws one instanced quad per gaussian.

    Gaussian attributes and the sorted index list are passed to the shader
    program through two storage buffers: ``"gaussian_data"`` at binding 0
    and ``"gi"`` at binding 1.
    """

    def __init__(self, w, h):
        """Compile the shaders, upload the quad geometry and set GL state.

        Args:
            w: Viewport width passed to ``glViewport``.
            h: Viewport height passed to ``glViewport``.
        """
        super().__init__()
        gl.glViewport(0, 0, w, h)
        self.program = util.load_shaders('shaders/gau_vert.glsl', 'shaders/gau_frag.glsl')

        # Vertex data for a quad
        self.quad_v = np.array([
            -1, 1,
            1, 1,
            1, -1,
            -1, -1,
        ], dtype=np.float32).reshape(4, 2)
        # Two triangles indexing into quad_v.
        self.quad_f = np.array([
            0, 1, 2,
            0, 2, 3,
        ], dtype=np.uint32).reshape(2, 3)

        # load quad geometry (the returned buffer id is not needed afterwards)
        vao, _ = util.set_attributes(self.program, ["position"], [self.quad_v])
        util.set_faces_tovao(vao, self.quad_f)
        self.vao = vao
        # Storage buffer ids; created on first upload and reused afterwards.
        self.gau_bufferid = None
        self.index_bufferid = None

        # opengl settings
        gl.glDisable(gl.GL_CULL_FACE)
        gl.glEnable(gl.GL_BLEND)
        gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA)

        self.update_vsync()

    def update_vsync(self):
        """Set the swap interval to 1 or 0 according to ``reduce_updates``.

        Prints a notice instead when ``wglSwapIntervalEXT`` is unavailable.
        """
        if wglSwapIntervalEXT is not None:
            wglSwapIntervalEXT(1 if self.reduce_updates else 0)
        else:
            print("VSync is not supported")

    def update_gaussian_data(self, gaus: util_gau.GaussianData):
        """Store ``gaus`` and upload its flattened data to the shader."""
        self.gaussians = gaus
        # load gaussian geometry
        gaussian_data = gaus.flat()
        self.gau_bufferid = util.set_storage_buffer_data(
            self.program, "gaussian_data", gaussian_data,
            bind_idx=0,
            buffer_id=self.gau_bufferid,
        )
        util.set_uniform_1int(self.program, gaus.sh_dim, "sh_dim")

    def sort_and_update(self, camera: util.Camera):
        """Sort the gaussians by depth for ``camera`` and upload the order.

        Does nothing when no gaussian data has been loaded yet.
        """
        if self.gaussians is None:
            # Nothing to sort before update_gaussian_data() has been called.
            return
        index = _sort_gaussian(self.gaussians, camera.get_view_matrix())
        self.index_bufferid = util.set_storage_buffer_data(
            self.program, "gi", index,
            bind_idx=1,
            buffer_id=self.index_bufferid,
        )

    def set_scale_modifier(self, modifier):
        """Set the ``scale_modifier`` shader uniform."""
        util.set_uniform_1f(self.program, modifier, "scale_modifier")

    def set_render_mod(self, mod: int):
        """Set the ``render_mod`` shader uniform."""
        util.set_uniform_1int(self.program, mod, "render_mod")

    def set_render_reso(self, w, h):
        """Resize the GL viewport to ``w`` x ``h``."""
        gl.glViewport(0, 0, w, h)

    def update_camera_pose(self, camera: util.Camera):
        """Upload the view matrix and position of ``camera`` as uniforms."""
        view_mat = camera.get_view_matrix()
        util.set_uniform_mat4(self.program, view_mat, "view_matrix")
        util.set_uniform_v3(self.program, camera.position, "cam_pos")

    def update_camera_intrin(self, camera: util.Camera):
        """Upload the projection matrix and ``hfovxy_focal`` of ``camera``."""
        proj_mat = camera.get_project_matrix()
        util.set_uniform_mat4(self.program, proj_mat, "projection_matrix")
        util.set_uniform_v3(self.program, camera.get_htanfovxy_focal(), "hfovxy_focal")

    def draw(self):
        """Draw the quad once per loaded gaussian.

        Does nothing when no gaussian data has been loaded yet.
        """
        if self.gaussians is None:
            # Previously this raised TypeError from len(None).
            return
        gl.glUseProgram(self.program)
        gl.glBindVertexArray(self.vao)
        num_gau = len(self.gaussians)
        # quad_f.size is the total number of indices (2 triangles * 3).
        gl.glDrawElementsInstanced(gl.GL_TRIANGLES, self.quad_f.size, gl.GL_UNSIGNED_INT, None, num_gau)