diff --git a/ft.py b/ft.py index 6b85b9e..1a7a8fe 100644 --- a/ft.py +++ b/ft.py @@ -21,6 +21,8 @@ class FT: self.model_name ) + self.model.to(self.device) + # set up optimizer self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5) self.scheduler = ReduceLROnPlateau(