From 4ad8e9b763c8205f1203f056d57d0bcb0f880166 Mon Sep 17 00:00:00 2001 From: Matthias Norden <59486396+mno-93@users.noreply.github.com> Date: Thu, 8 Aug 2024 09:55:37 +0200 Subject: [PATCH] Update demo.py to use CPU if CUDA not available --- demo.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/demo.py b/demo.py index 9fadc75..4489ff2 100644 --- a/demo.py +++ b/demo.py @@ -100,11 +100,15 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text # load model weights model = model_static(model_weight) model_dict = model.state_dict() - snapshot = torch.load(model_weight) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + snapshot = torch.load(model_weight, map_location=device) model_dict.update(snapshot) model.load_state_dict(model_dict) - model.cuda() + # Move the model to GPU if available + model.to(device) + + model.train(False) # video reading loop @@ -148,7 +152,8 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text img = torch.cat([img, img_jittered]) # forward pass - output = model(img.cuda()) + img = img.to(device) # Ensure the image tensor is on the correct device + output = model(img) if jitter > 0: output = torch.mean(output, 0) score = F.sigmoid(output).item() @@ -177,7 +182,7 @@ def run(video_path, face_path, model_weight, jitter, vis, display_off, save_text if save_text: f.close() cap.release() - print 'DONE!' + print('DONE!') if __name__ == "__main__":