@@ -265,50 +265,50 @@ def test_pretraining_tagger():
265
265
266
266
267
267
# Try to debug segfault on windows
268
- # def test_pretraining_training():
269
- # """Test that training can use a pretrained Tok2Vec model"""
270
- # config = Config().from_str(pretrain_string_internal)
271
- # nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
272
- # filled = nlp.config
273
- # pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
274
- # filled = pretrain_config.merge(filled)
275
- # train_config = util.load_config(DEFAULT_CONFIG_PATH)
276
- # filled = train_config.merge(filled)
277
- # with make_tempdir() as tmp_dir:
278
- # pretrain_dir = tmp_dir / "pretrain"
279
- # pretrain_dir.mkdir()
280
- # file_path = write_sample_jsonl(pretrain_dir)
281
- # filled["paths"]["raw_text"] = file_path
282
- # filled["pretraining"]["component"] = "tagger"
283
- # filled["pretraining"]["layer"] = "tok2vec"
284
- # train_dir = tmp_dir / "train"
285
- # train_dir.mkdir()
286
- # train_path, dev_path = write_sample_training(train_dir)
287
- # filled["paths"]["train"] = train_path
288
- # filled["paths"]["dev"] = dev_path
289
- # filled = filled.interpolate()
290
- # P = filled["pretraining"]
291
- # nlp_base = init_nlp(filled)
292
- # model_base = (
293
- # nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed")
294
- # )
295
- # embed_base = None
296
- # for node in model_base.walk():
297
- # if node.name == "hashembed":
298
- # embed_base = node
299
- # pretrain(filled, pretrain_dir)
300
- # pretrained_model = Path(pretrain_dir / "model3.bin")
301
- # assert pretrained_model.exists()
302
- # filled["initialize"]["init_tok2vec"] = str(pretrained_model)
303
- # nlp = init_nlp(filled)
304
- # model = nlp.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed")
305
- # embed = None
306
- # for node in model.walk():
307
- # if node.name == "hashembed":
308
- # embed = node
309
- # # ensure that the tok2vec weights are actually changed by the pretraining
310
- # assert np.any(np.not_equal(embed.get_param("E"), embed_base.get_param("E")))
311
- # train(nlp, train_dir)
268
def test_pretraining_training():
    """Test that training can use a pretrained Tok2Vec model.

    End-to-end flow: build an nlp object from the internal pretrain config,
    merge in the default pretrain and train configs, run pretraining, then
    initialize a fresh pipeline from the pretrained weights and verify the
    embedding parameters actually changed before running training.
    (Previously disabled while debugging a segfault on Windows.)
    """
    config = Config().from_str(pretrain_string_internal)
    nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
    filled = nlp.config
    # Layer the default pretrain and train configs under the model config.
    pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
    filled = pretrain_config.merge(filled)
    train_config = util.load_config(DEFAULT_CONFIG_PATH)
    filled = train_config.merge(filled)
    with make_tempdir() as tmp_dir:
        pretrain_dir = tmp_dir / "pretrain"
        pretrain_dir.mkdir()
        file_path = write_sample_jsonl(pretrain_dir)
        filled["paths"]["raw_text"] = file_path
        # Pretrain against the tagger's internal tok2vec layer.
        filled["pretraining"]["component"] = "tagger"
        filled["pretraining"]["layer"] = "tok2vec"
        train_dir = tmp_dir / "train"
        train_dir.mkdir()
        train_path, dev_path = write_sample_training(train_dir)
        filled["paths"]["train"] = train_path
        filled["paths"]["dev"] = dev_path
        filled = filled.interpolate()
        P = filled["pretraining"]
        # Capture the embedding node BEFORE pretraining so we can compare
        # its weights afterwards.
        nlp_base = init_nlp(filled)
        model_base = (
            nlp_base.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed")
        )
        embed_base = None
        for node in model_base.walk():
            if node.name == "hashembed":
                embed_base = node
        pretrain(filled, pretrain_dir)
        # pretrain() is expected to write epoch checkpoints; model3.bin is
        # the final one for this config.
        pretrained_model = Path(pretrain_dir / "model3.bin")
        assert pretrained_model.exists()
        # Re-initialize the pipeline from the pretrained weights.
        filled["initialize"]["init_tok2vec"] = str(pretrained_model)
        nlp = init_nlp(filled)
        model = nlp.get_pipe(P["component"]).model.get_ref(P["layer"]).get_ref("embed")
        embed = None
        for node in model.walk():
            if node.name == "hashembed":
                embed = node
        # ensure that the tok2vec weights are actually changed by the pretraining
        assert np.any(np.not_equal(embed.get_param("E"), embed_base.get_param("E")))
        train(nlp, train_dir)
312
312
313
313
314
314
def write_sample_jsonl (tmp_dir ):
0 commit comments