diff --git a/demo.py b/demo.py index 9fadc75..4489ff2 100644 --- a/demo.py +++ b/demo.py @@ -100,11 +100,15 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text # load model weights model = model_static(model_weight) model_dict = model.state_dict() - snapshot = torch.load(model_weight) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + snapshot = torch.load(model_weight, map_location=device) model_dict.update(snapshot) model.load_state_dict(model_dict) - model.cuda() + # Move the model to GPU if available + model.to(device) + + model.train(False) # video reading loop @@ -148,7 +152,8 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text img = torch.cat([img, img_jittered]) # forward pass - output = model(img.cuda()) + img = img.to(device) # Ensure the image tensor is on the correct device + output = model(img) if jitter > 0: output = torch.mean(output, 0) score = F.sigmoid(output).item() @@ -177,7 +182,7 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text if save_text: f.close() cap.release() - print 'DONE!' + print('DONE!') if __name__ == "__main__":