Skip to content

Commit 9e66324

Browse files
authored
Remove wp.Volume (#683)
* remove wp.Volume * auto-format * fix sensor * fix kernel arguments * fix small bugs * fix indentation * add box_project function * use analytic gradient where available * remove redundant aabb check * fix test requirement
1 parent e29f928 commit 9e66324

File tree

5 files changed

+123
-146
lines changed

5 files changed

+123
-146
lines changed

mujoco_warp/_src/collision_sdf.py

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from .types import GeomType
2828
from .types import Model
2929
from .types import vec5
30+
from .types import vec8f
31+
from .types import vec8i
3032
from .util_misc import halton
3133
from .warp_util import event_scope
3234

@@ -47,9 +49,12 @@ class AABB:
4749

4850
@wp.struct
4951
class VolumeData:
50-
volume_id: wp.uint64
5152
center: wp.vec3
5253
half_size: wp.vec3
54+
oct_aabb: wp.array2d(dtype=wp.vec3)
55+
oct_child: wp.array(dtype=vec8i)
56+
oct_coeff: wp.array(dtype=vec8f)
57+
valid: bool = False
5358

5459

5560
@wp.struct
@@ -68,20 +73,12 @@ class MeshData:
6873
valid: bool = False
6974

7075

71-
@wp.func
72-
def get_volume_data(volume_id: wp.uint64, center: wp.vec3, half_size: wp.vec3) -> VolumeData:
73-
volume_data = VolumeData()
74-
volume_data.volume_id = volume_id
75-
volume_data.center = center
76-
volume_data.half_size = half_size
77-
return volume_data
78-
79-
8076
@wp.func
8177
def get_sdf_params(
8278
# Model:
83-
volume_ids: wp.array(dtype=wp.uint64),
8479
oct_aabb: wp.array2d(dtype=wp.vec3),
80+
oct_child: wp.array(dtype=vec8i),
81+
oct_coeff: wp.array(dtype=vec8f),
8582
plugin: wp.array(dtype=int),
8683
plugin_attr: wp.array(dtype=wp.vec3f),
8784
# In:
@@ -92,17 +89,19 @@ def get_sdf_params(
9289
) -> Tuple[wp.vec3, int, VolumeData, MeshData]:
9390
attributes = g_size
9491
plugin_index = -1
95-
volume_data = get_volume_data(wp.uint64(0), wp.vec3(0.0), wp.vec3(0.0))
92+
volume_data = VolumeData()
9693

9794
if g_type == int(GeomType.SDF.value) and plugin_id != -1:
9895
attributes = plugin_attr[plugin_id]
9996
plugin_index = plugin[plugin_id]
10097

10198
elif g_type == int(GeomType.SDF.value) and mesh_id != -1:
102-
volume_id = volume_ids[mesh_id]
103-
center = oct_aabb[mesh_id, 0]
104-
half_size = oct_aabb[mesh_id, 1]
105-
volume_data = get_volume_data(volume_id, center, half_size)
99+
volume_data.center = oct_aabb[mesh_id, 0]
100+
volume_data.half_size = oct_aabb[mesh_id, 1]
101+
volume_data.oct_aabb = oct_aabb
102+
volume_data.oct_child = oct_child
103+
volume_data.oct_coeff = oct_coeff
104+
volume_data.valid = True
106105

107106
return attributes, plugin_index, volume_data, MeshData()
108107

@@ -223,22 +222,73 @@ def user_sdf_grad(p: wp.vec3, attr: wp.vec3, sdf_type: int) -> wp.vec3:
223222

224223

225224
@wp.func
226-
def sample_volume_sdf(xyz: wp.vec3, volume_data: VolumeData) -> float:
227-
center = volume_data.center
228-
half_size = volume_data.half_size
225+
def find_oct(
226+
oct_aabb: wp.array2d(dtype=wp.vec3), oct_child: wp.array(dtype=vec8i), p: wp.vec3, grad: bool
227+
) -> Tuple[int, Tuple[vec8f, vec8f, vec8f]]:
228+
stack = int(0)
229+
niter = int(100)
230+
rx = vec8f(0.0)
231+
ry = vec8f(0.0)
232+
rz = vec8f(0.0)
233+
234+
while niter > 0:
235+
niter -= 1
236+
node = stack
237+
238+
if node == -1:
239+
wp.printf("ERROR: Invalid node number\n")
240+
return -1, (rx, ry, rz)
241+
242+
vmin = oct_aabb[node, 0] - oct_aabb[node, 1]
243+
vmax = oct_aabb[node, 0] + oct_aabb[node, 1]
244+
coord = wp.cw_div(p - vmin, vmax - vmin)
245+
246+
# check if the node is a leaf
247+
if (
248+
oct_child[node][0] == -1
249+
and oct_child[node][1] == -1
250+
and oct_child[node][2] == -1
251+
and oct_child[node][3] == -1
252+
and oct_child[node][4] == -1
253+
and oct_child[node][5] == -1
254+
and oct_child[node][6] == -1
255+
and oct_child[node][7] == -1
256+
):
257+
for j in range(8):
258+
if not grad:
259+
rx[j] = (
260+
(coord[0] if j & 1 else 1.0 - coord[0])
261+
* (coord[1] if j & 2 else 1.0 - coord[1])
262+
* (coord[2] if j & 4 else 1.0 - coord[2])
263+
)
264+
else:
265+
rx[j] = (1.0 if j & 1 else -1.0) * (coord[1] if j & 2 else 1.0 - coord[1]) * (coord[2] if j & 4 else 1.0 - coord[2])
266+
ry[j] = (coord[0] if j & 1 else 1.0 - coord[0]) * (1.0 if j & 2 else -1.0) * (coord[2] if j & 4 else 1.0 - coord[2])
267+
rz[j] = (coord[0] if j & 1 else 1.0 - coord[0]) * (coord[1] if j & 2 else 1.0 - coord[1]) * (1.0 if j & 4 else -1.0)
268+
return node, (rx, ry, rz)
269+
270+
# compute which of 8 children to visit next
271+
x = 1 if coord[0] < 0.5 else 0
272+
y = 1 if coord[1] < 0.5 else 0
273+
z = 1 if coord[2] < 0.5 else 0
274+
stack = oct_child[node][4 * z + 2 * y + x]
275+
276+
wp.print("ERROR: Node not found\n")
277+
return -1, (rx, ry, rz)
229278

279+
280+
@wp.func
281+
def box_project(center: wp.vec3, half_size: wp.vec3, xyz: wp.vec3) -> Tuple[float, wp.vec3]:
230282
r = xyz - center
231283
q = wp.vec3(wp.abs(r[0]) - half_size[0], wp.abs(r[1]) - half_size[1], wp.abs(r[2]) - half_size[2])
232284

233285
if q[0] <= 0.0 and q[1] <= 0.0 and q[2] <= 0.0:
234-
uvw = wp.volume_world_to_index(volume_data.volume_id, xyz)
235-
sdf = wp.volume_sample_f(volume_data.volume_id, uvw, wp.Volume.LINEAR)
236-
return sdf
286+
return 0.0, xyz
237287

238288
else:
239-
point = wp.vec3(xyz[0], xyz[1], xyz[2])
240289
dist_sqr = 0.0
241290
eps = 1e-4
291+
point = wp.vec3(xyz[0], xyz[1], xyz[2])
242292

243293
if q[0] >= 0.0:
244294
dist_sqr += q[0] * q[0]
@@ -261,23 +311,33 @@ def sample_volume_sdf(xyz: wp.vec3, volume_data: VolumeData) -> float:
261311
else:
262312
point = wp.vec3(point[0], point[1], point[2] + (q[2] + eps))
263313

264-
dist0 = wp.sqrt(dist_sqr)
314+
return wp.sqrt(dist_sqr), point
265315

266-
uvw = wp.volume_world_to_index(volume_data.volume_id, point)
267-
sdf = wp.volume_sample_f(volume_data.volume_id, uvw, wp.Volume.LINEAR)
268-
return dist0 + sdf
316+
317+
@wp.func
318+
def sample_volume_sdf(xyz: wp.vec3, volume_data: VolumeData) -> float:
319+
dist0, point = box_project(volume_data.center, volume_data.half_size, xyz)
320+
node, weights = find_oct(volume_data.oct_aabb, volume_data.oct_child, point, grad=False)
321+
return dist0 + wp.dot(weights[0], volume_data.oct_coeff[node])
269322

270323

271324
@wp.func
272325
def sample_volume_grad(xyz: wp.vec3, volume_data: VolumeData) -> wp.vec3:
273-
h = 1e-4
274-
dx = wp.vec3(h, 0.0, 0.0)
275-
dy = wp.vec3(0.0, h, 0.0)
276-
dz = wp.vec3(0.0, 0.0, h)
277-
f = sample_volume_sdf(xyz, volume_data)
278-
grad_x = (sample_volume_sdf(xyz + dx, volume_data) - f) / h
279-
grad_y = (sample_volume_sdf(xyz + dy, volume_data) - f) / h
280-
grad_z = (sample_volume_sdf(xyz + dz, volume_data) - f) / h
326+
dist0, point = box_project(volume_data.center, volume_data.half_size, xyz)
327+
if dist0 > 0:
328+
h = 1e-4
329+
dx = wp.vec3(h, 0.0, 0.0)
330+
dy = wp.vec3(0.0, h, 0.0)
331+
dz = wp.vec3(0.0, 0.0, h)
332+
f = sample_volume_sdf(xyz, volume_data)
333+
grad_x = (sample_volume_sdf(xyz + dx, volume_data) - f) / h
334+
grad_y = (sample_volume_sdf(xyz + dy, volume_data) - f) / h
335+
grad_z = (sample_volume_sdf(xyz + dz, volume_data) - f) / h
336+
return wp.vec3(grad_x, grad_y, grad_z)
337+
node, weights = find_oct(volume_data.oct_aabb, volume_data.oct_child, point, grad=True)
338+
grad_x = wp.dot(weights[0], volume_data.oct_coeff[node])
339+
grad_y = wp.dot(weights[1], volume_data.oct_coeff[node])
340+
grad_z = wp.dot(weights[2], volume_data.oct_coeff[node])
281341
return wp.vec3(grad_x, grad_y, grad_z)
282342

283343

@@ -579,8 +639,9 @@ def _sdf_narrowphase(
579639
mesh_polymapadr: wp.array(dtype=int),
580640
mesh_polymapnum: wp.array(dtype=int),
581641
mesh_polymap: wp.array(dtype=int),
582-
volume_ids: wp.array(dtype=wp.uint64),
583642
oct_aabb: wp.array2d(dtype=wp.vec3),
643+
oct_child: wp.array(dtype=vec8i),
644+
oct_coeff: wp.array(dtype=vec8f),
584645
pair_dim: wp.array(dtype=int),
585646
pair_solref: wp.array2d(dtype=wp.vec2),
586647
pair_solreffriction: wp.array2d(dtype=wp.vec2),
@@ -732,11 +793,11 @@ def _sdf_narrowphase(
732793
rot1 = geom1.rot
733794

734795
attr1, g1_plugin_id, volume_data1, mesh_data1 = get_sdf_params(
735-
volume_ids, oct_aabb, plugin, plugin_attr, type1, geom1.size, g1_plugin, geom_dataid[g1]
796+
oct_aabb, oct_child, oct_coeff, plugin, plugin_attr, type1, geom1.size, g1_plugin, geom_dataid[g1]
736797
)
737798

738799
attr2, g2_plugin_id, volume_data2, mesh_data2 = get_sdf_params(
739-
volume_ids, oct_aabb, plugin, plugin_attr, type2, geom2.size, g2_plugin, geom_dataid[g2]
800+
oct_aabb, oct_child, oct_coeff, plugin, plugin_attr, type2, geom2.size, g2_plugin, geom_dataid[g2]
740801
)
741802

742803
mesh_data1.nmeshface = nmeshface
@@ -857,8 +918,9 @@ def sdf_narrowphase(m: Model, d: Data):
857918
m.mesh_polymapadr,
858919
m.mesh_polymapnum,
859920
m.mesh_polymap,
860-
m.volume_ids,
861921
m.oct_aabb,
922+
m.oct_child,
923+
m.oct_coeff,
862924
m.pair_dim,
863925
m.pair_solref,
864926
m.pair_solreffriction,

mujoco_warp/_src/io.py

Lines changed: 3 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def put_model(mjm: mujoco.MjModel) -> types.Model:
9696

9797
plugin_id = np.array(plugin_id)
9898
plugin_attr = np.array(plugin_attr)
99-
volume_ids, volumes, oct_aabb = _mujoco_octree_to_warp_volume(mjm)
10099

101100
if mjm.nflex > 1:
102101
raise NotImplementedError("Only one flex is unsupported.")
@@ -637,9 +636,9 @@ def create_nmodel_batched_array(mjm_array, dtype, expand_dim=True):
637636
mesh_polymapadr=wp.array(mjm.mesh_polymapadr, dtype=int),
638637
mesh_polymapnum=wp.array(mjm.mesh_polymapnum, dtype=int),
639638
mesh_polymap=wp.array(mjm.mesh_polymap, dtype=int),
640-
volume_ids=volume_ids,
641-
volumes=volumes,
642-
oct_aabb=oct_aabb,
639+
oct_aabb=wp.array2d(mjm.oct_aabb, dtype=wp.vec3),
640+
oct_child=wp.array(mjm.oct_child, dtype=types.vec8i),
641+
oct_coeff=wp.array(mjm.oct_coeff, dtype=types.vec8f),
643642
nhfield=mjm.nhfield,
644643
nhfielddata=mjm.nhfielddata,
645644
hfield_adr=wp.array(mjm.hfield_adr, dtype=int),
@@ -858,85 +857,6 @@ def create_nmodel_batched_array(mjm_array, dtype, expand_dim=True):
858857
return m
859858

860859

861-
def _mujoco_octree_to_warp_volume(
862-
mjm: mujoco.MjModel, resolution: int = 128
863-
) -> Tuple[wp.array, Tuple[wp.Volume, ...], wp.array]:
864-
"""Constructs volume data from MuJoCo octrees."""
865-
volume_ids = [0] * len(mjm.mesh_octadr)
866-
volumes = []
867-
oct_aabbs = [None] * len(mjm.mesh_octadr)
868-
for mesh_id in mjm.geom_dataid:
869-
if mesh_id != -1 and mesh_id < len(mjm.mesh_octadr):
870-
octadr = mjm.mesh_octadr[mesh_id]
871-
octnum = mjm.mesh_octnum[mesh_id]
872-
if octadr != -1:
873-
oct_child = mjm.oct_child[octadr : (octadr + octnum), :]
874-
oct_aabb = mjm.oct_aabb[octadr : (octadr + octnum), :]
875-
oct_coeff = mjm.oct_coeff[octadr : (octadr + octnum), :]
876-
877-
root_aabb = oct_aabb[0]
878-
center = root_aabb[:3]
879-
half_size = root_aabb[3:]
880-
881-
original_mins = center - half_size
882-
original_maxs = center + half_size
883-
884-
margin_factor = 0.02
885-
extents = original_maxs - original_mins
886-
margin = margin_factor * extents
887-
888-
mins = original_mins - margin
889-
maxs = original_maxs + margin
890-
expanded_extents = maxs - mins
891-
892-
voxel_size = expanded_extents.max() / resolution
893-
894-
nums = np.ceil(expanded_extents / voxel_size).astype(dtype=int)
895-
896-
actual_extents = nums * voxel_size
897-
maxs = mins + actual_extents
898-
899-
sdf_values = np.zeros(tuple(nums), dtype=np.float32)
900-
901-
for x in range(nums[0]):
902-
for y in range(nums[1]):
903-
for z in range(nums[2]):
904-
pos = mins + voxel_size * np.array([x, y, z])
905-
within_bounds = np.all(pos >= original_mins) and np.all(pos <= original_maxs)
906-
if within_bounds:
907-
sdf_val = sample_octree_sdf(pos, oct_child, oct_aabb, oct_coeff)
908-
else:
909-
clamped_pos = np.clip(pos, original_mins, original_maxs)
910-
sdf_val = sample_octree_sdf(clamped_pos, oct_child, oct_aabb, oct_coeff)
911-
sdf_values[x, y, z] = sdf_val
912-
913-
device = wp.get_device()
914-
if device.is_cuda:
915-
volume = wp.Volume.load_from_numpy(sdf_values, mins, voxel_size, 1.0, device=device)
916-
volume_ids[mesh_id] = volume.id
917-
volumes.append(volume)
918-
else:
919-
volume_ids[mesh_id] = 0
920-
oct_aabbs[mesh_id] = [center, half_size]
921-
922-
volume_ids_array = wp.array(data=volume_ids, dtype=wp.uint64)
923-
924-
processed_aabbs = []
925-
for aabb in oct_aabbs:
926-
if aabb is not None:
927-
processed_aabbs.append(aabb)
928-
else:
929-
zero_center = np.zeros(3, dtype=np.float32)
930-
zero_half_size = np.zeros(3, dtype=np.float32)
931-
processed_aabbs.append([zero_center, zero_half_size])
932-
933-
aabb_array = np.array(processed_aabbs, dtype=np.float32)
934-
oct_aabb_array = wp.array2d(data=aabb_array, dtype=wp.vec3)
935-
volumes_tuple = tuple(volumes)
936-
937-
return volume_ids_array, volumes_tuple, oct_aabb_array
938-
939-
940860
def make_data(mjm: mujoco.MjModel, nworld: int = 1, nconmax: int = -1, njmax: int = -1) -> types.Data:
941861
"""
942862
Creates a data object on device.

mujoco_warp/_src/io_test.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,6 @@ def _check_annotation_compat(
130130
if isinstance(v, wp.types.array):
131131
continue
132132

133-
if v == wp.Volume:
134-
continue
135-
136133
if v in wp.types.vector_types:
137134
raise AssertionError(f"Vector types are not allowed. {info}")
138135

@@ -392,30 +389,16 @@ def test_contact_sensor(self, contact_sensor):
392389
with self.assertRaises(NotImplementedError):
393390
mjwarp.put_model(mjm)
394391

395-
def test_volumes(self):
396-
"""Tests that mujoco_octree_to_warp_volume properly processes SDF volumes."""
397-
if not wp.get_device().is_cuda:
398-
self.skipTest("SDF volumes require CUDA device")
392+
def test_sdf(self):
393+
"""Tests that an SDF can be loaded."""
399394
mjm, mjd, m, d = test_util.fixture(fname="collision_sdf/cow.xml", qpos0=True)
400395

401-
self.assertIsInstance(m.volume_ids, wp.array)
402-
self.assertEqual(m.volume_ids.dtype, wp.uint64)
403-
self.assertGreater(m.volume_ids.size, 0)
404-
405-
self.assertIsInstance(m.volumes, tuple)
406-
if len(m.volumes) > 0:
407-
for volume in m.volumes:
408-
self.assertIsInstance(volume, wp.Volume)
409-
410396
self.assertIsInstance(m.oct_aabb, wp.array)
411397
self.assertEqual(m.oct_aabb.dtype, wp.vec3)
412398
self.assertEqual(len(m.oct_aabb.shape), 2)
413399
if m.oct_aabb.size > 0:
414400
self.assertEqual(m.oct_aabb.shape[1], 2)
415401

416-
volume_ids_numpy = m.volume_ids.numpy()
417-
self.assertEqual(len(m.volumes), np.unique(volume_ids_numpy).size)
418-
419402

420403
if __name__ == "__main__":
421404
wp.init()

0 commit comments

Comments
 (0)