Move the model to gpu

main
Gašper Spagnolo 2023-08-01 14:58:24 +02:00
parent f1e06e97ad
commit acd5250c12
No known key found for this signature in database
GPG Key ID: 2EA0738CC1EFEEB7
1 changed files with 2 additions and 0 deletions

2
ft.py
View File

@ -21,6 +21,8 @@ class FT:
self.model_name self.model_name
) )
self.model.to(self.device)
# set up optimizer # set up optimizer
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-5)
self.scheduler = ReduceLROnPlateau( self.scheduler = ReduceLROnPlateau(