
Commit f72c67a ("eval ready")
1 parent 465202a

4 files changed, +53 -43 lines changed

README.md

Lines changed: 14 additions & 6 deletions
````diff
@@ -69,7 +69,7 @@ Download the pretrained models and training logs in [release](https://github.com
 
 # :mag_right: Testing
 
-See [test.ipynb](test.ipynb) for a simple view synthesis and depth prediction on 1 image.
+Example: [test_nerf-u.ipynb](test_nerf-u.ipynb) shows how NeRF-U successfully decomposes the scene into static and transient components.
 
 Use [eval.py](eval.py) to create the whole sequence of moving views.
 E.g.
@@ -79,14 +79,22 @@ python eval.py \
    --dataset_name blender --scene_name lego \
    --img_wh 400 400 --N_importance 64 --ckpt_path $CKPT_PATH
 ```
-**IMPORTANT**: Don't forget to add `--spheric_poses` if the model is trained under the `--spheric` setting!
 
 It will create the folder `results/{dataset_name}/{scene_name}`, run inference on all test data, and finally create a gif out of them.
 
-
+Example of the lego scene using the pretrained **NeRF-U** model under the **occluder** condition (PSNR=28.60, paper=23.47):
+![nerf-u](https://user-images.githubusercontent.com/11364490/105578186-a9933400-5dc1-11eb-8865-e276b581d8fd.gif)
 
 # :warning: Notes on differences with the original repo
 
-* The learning rate decay in the original repo is **by step** (it decreases every step); here I use learning rate decay **by epoch**, so it changes only at the end of each epoch.
-* The validation image for the LLFF dataset is chosen as the most centered image here, whereas the original repo chooses every 8th image.
-* The rendering spiral path is slightly different from the original repo's (I use approximate values to simplify the code).
+* Network structure ([nerf.py](models/nerf.py)):
+    * My base MLP uses 8 layers of 256 units as in the original NeRF, while NeRF-W uses **512** units each.
+    * My static head uses 1 layer as in the original NeRF, while NeRF-W uses **4** layers.
+    * I use **softplus** activation for sigma (reason explained [here](https://github.com/bmild/nerf/issues/29#issuecomment-765335765)) while NeRF-W uses **relu**.
+
+* Training hyperparameters:
+    * I find that a larger `beta_min` achieves better results, so my default `beta_min` is `0.1` instead of the paper's `0.03`.
+    * I add 3 to `beta_loss` (equation 13) to make it positive empirically (see the sketch after this diff).
+
+* Evaluation:
+    * The evaluation metric is computed on the **test** set, while NeRF evaluates on val and test combined.
````
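For reference, here is how the `beta_min` / `beta_loss` notes above could look in code. This is a minimal sketch of the NeRF-W color loss (equation 13) with the +3 offset; the function name, tensor shapes, and the `lambda_u` weight are illustrative assumptions, not the repo's exact implementation:

```python
import torch

def nerfw_loss(rgb_pred, rgb_gt, beta, transient_sigmas, lambda_u=0.01):
    # Color term: squared error attenuated by the predicted uncertainty beta.
    color_loss = ((rgb_pred - rgb_gt) ** 2 / (2 * beta.unsqueeze(-1) ** 2)).mean()
    # log(beta) term of eq. 13; with beta >= beta_min = 0.1, log(beta) >= -2.3,
    # so adding 3 keeps this term positive (the empirical trick noted above).
    beta_loss = 3 + torch.log(beta).mean()
    # Regularizer on the transient density (lambda_u is an assumed weight).
    sigma_loss = lambda_u * transient_sigmas.mean()
    return color_loss + beta_loss + sigma_loss

# The softplus-vs-relu note above amounts to activating sigma with
# torch.nn.functional.softplus(raw_sigma) instead of relu, so its gradient never dies.
```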

eval.py

Lines changed: 33 additions & 31 deletions
```diff
@@ -23,17 +23,16 @@ def get_opts():
                         default='/home/ubuntu/data/nerf_example_data/nerf_synthetic/lego',
                         help='root directory of dataset')
     parser.add_argument('--dataset_name', type=str, default='blender',
-                        choices=['blender', 'llff'],
+                        choices=['blender'],
                         help='which dataset to validate')
     parser.add_argument('--scene_name', type=str, default='test',
                         help='scene name, used as output folder name')
-    parser.add_argument('--split', type=str, default='test',
-                        help='test or test_train')
+    parser.add_argument('--split', type=str, default='val',
+                        choices=['val', 'test', 'test_train'])
     parser.add_argument('--img_wh', nargs="+", type=int, default=[800, 800],
                         help='resolution (img_w, img_h) of the image')
-    parser.add_argument('--spheric_poses', default=False, action="store_true",
-                        help='whether images are taken in spheric poses (for llff)')
 
+    # original NeRF parameters
     parser.add_argument('--N_emb_xyz', type=int, default=10,
                         help='number of xyz embedding frequencies')
     parser.add_argument('--N_emb_dir', type=int, default=4,
@@ -45,9 +44,19 @@ def get_opts():
     parser.add_argument('--use_disp', default=False, action="store_true",
                         help='use disparity depth sampling')
 
+    # NeRF-W parameters
+    parser.add_argument('--N_vocab', type=int, default=100,
+                        help='''number of vocabulary (number of images)
+                                in the dataset for nn.Embedding''')
+    parser.add_argument('--encode_a', default=False, action="store_true",
+                        help='whether to encode appearance (NeRF-A)')
+    parser.add_argument('--N_a', type=int, default=48,
+                        help='number of embeddings for appearance')
+    parser.add_argument('--encode_t', default=False, action="store_true",
+                        help='whether to encode transient object (NeRF-U)')
     parser.add_argument('--N_tau', type=int, default=16,
                         help='number of embeddings for transient objects')
-    parser.add_argument('--beta_min', type=float, default=0.03,
+    parser.add_argument('--beta_min', type=float, default=0.1,
                         help='minimum color variance for each ray')
 
     parser.add_argument('--chunk', type=int, default=32*1024*4,
@@ -56,12 +65,6 @@ def get_opts():
     parser.add_argument('--ckpt_path', type=str, required=True,
                         help='pretrained checkpoint path to load')
 
-    parser.add_argument('--save_depth', default=False, action="store_true",
-                        help='whether to save depth prediction')
-    parser.add_argument('--depth_format', type=str, default='pfm',
-                        choices=['pfm', 'bytes'],
-                        help='which format to save')
-
     return parser.parse_args()
 
 
@@ -103,24 +106,32 @@ def batched_inference(models, embeddings,
     kwargs = {'root_dir': args.root_dir,
               'split': args.split,
               'img_wh': tuple(args.img_wh)}
-    if args.dataset_name == 'llff':
-        kwargs['spheric_poses'] = args.spheric_poses
     dataset = dataset_dict[args.dataset_name](**kwargs)
 
-    embedding_t = torch.nn.Embedding(200, args.N_tau)
     embedding_xyz = PosEmbedding(args.N_emb_xyz-1, args.N_emb_xyz)
     embedding_dir = PosEmbedding(args.N_emb_dir-1, args.N_emb_dir)
-    nerf_coarse = NeRF('coarse')
-    nerf_fine = NeRF('fine', beta_min=args.beta_min)
-    load_ckpt(embedding_t, args.ckpt_path, model_name='embedding_t')
+    embeddings = {'xyz': embedding_xyz, 'dir': embedding_dir}
+    if args.encode_a:
+        embedding_a = torch.nn.Embedding(args.N_vocab, args.N_a).cuda()
+        load_ckpt(embedding_a, args.ckpt_path, model_name='embedding_a')
+        embeddings['a'] = embedding_a
+    if args.encode_t:
+        embedding_t = torch.nn.Embedding(args.N_vocab, args.N_tau).cuda()
+        load_ckpt(embedding_t, args.ckpt_path, model_name='embedding_t')
+        embeddings['t'] = embedding_t
+
+    nerf_coarse = NeRF('coarse').cuda()
+    nerf_fine = NeRF('fine',
+                     encode_appearance=args.encode_a,
+                     in_channels_a=args.N_a,
+                     encode_transient=args.encode_t,
+                     in_channels_t=args.N_tau,
+                     beta_min=args.beta_min).cuda()
+
     load_ckpt(nerf_coarse, args.ckpt_path, model_name='nerf_coarse')
     load_ckpt(nerf_fine, args.ckpt_path, model_name='nerf_fine')
-    embedding_t.cuda()
-    nerf_coarse.cuda()
-    nerf_fine.cuda()
 
     models = {'coarse': nerf_coarse, 'fine': nerf_fine}
-    embeddings = {'xyz': embedding_xyz, 'dir': embedding_dir, 't': embedding_t}
 
     imgs, psnrs = [], []
     dir_name = f'results/{args.dataset_name}/{args.scene_name}'
@@ -137,15 +148,6 @@ def batched_inference(models, embeddings,
 
         img_pred = results['rgb_fine'].view(h, w, 3).cpu().numpy()
 
-        if args.save_depth:
-            depth_pred = results['depth_fine'].view(h, w).cpu().numpy()
-            depth_pred = np.nan_to_num(depth_pred)
-            if args.depth_format == 'pfm':
-                save_pfm(os.path.join(dir_name, f'depth_{i:03d}.pfm'), depth_pred)
-            else:
-                with open(f'depth_{i:03d}', 'wb') as f:
-                    f.write(depth_pred.tobytes())
-
         img_pred_ = (img_pred*255).astype(np.uint8)
         imgs += [img_pred_]
         imageio.imwrite(os.path.join(dir_name, f'{i:03d}.png'), img_pred_)
```
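The key change above is that the per-image embeddings are now sized by `--N_vocab` and built only when the corresponding flag is set. A minimal sketch of the lookup pattern this enables (the ray count and image id below are made up for illustration):

```python
import torch

N_vocab, N_tau = 100, 16                       # defaults from the diff above
embedding_t = torch.nn.Embedding(N_vocab, N_tau)
# load_ckpt(embedding_t, ckpt_path, model_name='embedding_t')  # as in eval.py

ts = torch.full((1024,), 5, dtype=torch.long)  # one image id per ray (id < N_vocab)
t_embedded = embedding_t(ts)                   # (1024, 16) transient embeddings
```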

opt.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ def get_opts():
                         help='whether to encode transient object (NeRF-U)')
     parser.add_argument('--N_tau', type=int, default=16,
                         help='number of embeddings for transient objects')
-    parser.add_argument('--beta_min', type=float, default=0.03,
+    parser.add_argument('--beta_min', type=float, default=0.1,
                         help='minimum color variance for each ray')
 
     parser.add_argument('--batch_size', type=int, default=1024,
```

test_nerf-u.ipynb

Lines changed: 5 additions & 5 deletions
```diff
@@ -38,6 +38,11 @@
     "N_tau = 16\n",
     "beta_min = 0.1\n",
     "ckpt_path = 'ckpts/lego_nerfw_occ2/epoch=19.ckpt'\n",
+    "\n",
+    "N_samples = 64\n",
+    "N_importance = 64\n",
+    "use_disp = False\n",
+    "chunk = 1024*32\n",
     "#############################\n",
     "\n",
     "embedding_xyz = PosEmbedding(9, 10)\n",
@@ -73,11 +78,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "N_samples = 64\n",
-    "N_importance = 64\n",
-    "use_disp = False\n",
-    "chunk = 1024*32\n",
-    "\n",
     "@torch.no_grad()\n",
     "def f(rays, ts):\n",
     "    \"\"\"Do batched inference on rays using chunk.\"\"\"\n",
```
