Nous sommes heureux d’annoncer que la version 0.2.0 de torch
vient d’atterrir sur CRAN.
Cette version inclut de nombreuses corrections de bugs et quelques nouvelles fonctionnalités intéressantes que nous présenterons dans cet article de blog. Vous pouvez voir le journal des modifications complet dans le NOUVELLES.md déposer.
Les fonctionnalités dont nous discuterons en détail sont :
- Prise en charge initiale du traçage JIT
- Chargeurs de données multi-travailleurs
- Méthodes d’impression pour
nn_modules
Chargeurs de données multi-travailleurs
dataloaders
répondez maintenant au num_workers
argument et exécutera le prétraitement dans les travailleurs parallèles.
Par exemple, disons que nous avons l’ensemble de données factice suivant qui effectue un long calcul :
library(torch)
dat <- dataset(
"mydataset",
initialize = function(time, len = 10) {
self$time <- time
self$len <- len
},
.getitem = function(i) {
Sys.sleep(self$time)
torch_randn(1)
},
.length = function() {
self$len
}
)
ds <- dat(1)
system.time(ds[1])
user system elapsed
0.029 0.005 1.027
Nous allons maintenant créer deux chargeurs de données, un qui s’exécute séquentiellement et un autre qui s’exécute en parallèle.
seq_dl <- dataloader(ds, batch_size = 5)
par_dl <- dataloader(ds, batch_size = 5, num_workers = 2)
Nous pouvons désormais comparer le temps nécessaire pour traiter deux lots séquentiellement au temps nécessaire pour traiter en parallèle :
seq_it <- dataloader_make_iter(seq_dl)
par_it <- dataloader_make_iter(par_dl)
two_batches <- function(it) {
dataloader_next(it)
dataloader_next(it)
"ok"
}
system.time(two_batches(seq_it))
system.time(two_batches(par_it))
user system elapsed
0.098 0.032 10.086
user system elapsed
0.065 0.008 5.134
Notez que ce sont des lots qui sont obtenus en parallèle et non des observations individuelles. Ainsi, nous pourrons à l’avenir prendre en charge des ensembles de données avec des tailles de lots variables.
Utiliser plusieurs travailleurs est pas nécessairement plus rapide que l’exécution en série car il y a une surcharge considérable lors du passage des tenseurs d’un travailleur à la session principale ainsi que lors de l’initialisation des travailleurs.
Cette fonctionnalité est activée par le puissant callr
package et fonctionne dans tous les systèmes d’exploitation pris en charge par torch
. callr
créons des sessions R persistantes, et ainsi, nous ne payons qu’une seule fois les frais généraux liés au transfert d’objets de jeux de données potentiellement volumineux vers les travailleurs.
Lors du processus d’implémentation de cette fonctionnalité, nous avons fait en sorte que les chargeurs de données se comportent comme coro
itérateurs. Cela signifie que vous pouvez désormais utiliser coro
La syntaxe de pour parcourir les chargeurs de données :
coro::loop(for(batch in par_dl) {
print(batch$shape)
})
[1] 5 1
[1] 5 1
C’est le premier torch
version incluant la fonctionnalité de chargeurs de données multi-travailleurs, et vous pourriez rencontrer des cas extrêmes lors de son utilisation. Faites-nous savoir si vous rencontrez des problèmes.
Prise en charge initiale du JIT
Les programmes qui utilisent le torch
sont inévitablement des programmes R et, par conséquent, ils ont toujours besoin d’une installation R pour s’exécuter.
Depuis la version 0.2.0, torch
permet aux utilisateurs de JIT tracer
torch
R fonctionne dans TorchScript. Le traçage JIT (Juste à temps) invoquera une fonction R avec des exemples d’entrées, enregistrera toutes les opérations survenues lors de l’exécution de la fonction et renverra un script_function
objet contenant la représentation TorchScript.
La bonne nouvelle est que les programmes TorchScript sont facilement sérialisables, optimisables et peuvent être chargés par un autre programme écrit en PyTorch ou LibTorch sans nécessiter aucune dépendance R.
Supposons que vous ayez la fonction R suivante qui prend un tenseur, effectue une multiplication matricielle avec une matrice à poids fixe, puis ajoute un terme de biais :
w <- torch_randn(10, 1)
b <- torch_randn(1)
fn <- function(x) {
a <- torch_mm(x, w)
a + b
}
Cette fonction peut être tracée JIT dans TorchScript avec jit_trace
en passant la fonction et les exemples d’entrées :
x <- torch_ones(2, 10)
tr_fn <- jit_trace(fn, x)
tr_fn(x)
torch_tensor
-0.6880
-0.6880
[ CPUFloatType{2,1} ]
Maintenant tout torch
les opérations survenues lors du calcul du résultat de cette fonction ont été tracées et transformées en graphique :
graph(%0 : Float(2:10, 10:1, requires_grad=0, device=cpu)):
%1 : Float(10:1, 1:1, requires_grad=0, device=cpu) = prim::Constant[value=-0.3532 0.6490 -0.9255 0.9452 -1.2844 0.3011 0.4590 -0.2026 -1.2983 1.5800 [ CPUFloatType{10,1} ]]()
%2 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::mm(%0, %1)
%3 : Float(1:1, requires_grad=0, device=cpu) = prim::Constant[value={-0.558343}]()
%4 : int = prim::Constant[value=1]()
%5 : Float(2:1, 1:1, requires_grad=0, device=cpu) = aten::add(%2, %3, %4)
return (%5)
La fonction tracée peut être sérialisée avec jit_save
:
jit_save(tr_fn, "linear.pt")
Il peut être rechargé dans R avec jit_load
mais il peut aussi être rechargé en Python avec torch.jit.load
:
import torch
= torch.jit.load("linear.pt")
fn 2, 10)) fn(torch.ones(
tensor([[-0.6880],
[-0.6880]])
À quel point cela est cool?!
Il ne s’agit que du support initial de JIT dans R. Nous continuerons à le développer. Plus précisément, dans la prochaine version de torch
nous prévoyons de soutenir le traçage nn_modules
directement. Actuellement, vous devez détacher tous les paramètres avant de les tracer ; voir un exemple ici. Cela vous permettra également de profiter de TorchScript pour rendre vos modèles plus rapides !
Notez également que le traçage présente certaines limites, en particulier lorsque votre code comporte des boucles ou des instructions de flux de contrôle qui dépendent des données tensorielles. Voir ?jit_trace
pour apprendre plus.
Nouvelle méthode d’impression pour nn_modules
Dans cette version, nous avons également amélioré le nn_module
méthodes d’impression afin de faciliter la compréhension de ce qu’il y a à l’intérieur.
Par exemple, si vous créez une instance d’un nn_linear
module, vous verrez :
An `nn_module` containing 11 parameters.
── Parameters ──────────────────────────────────────────────────────────────────
● weight: Float [1:1, 1:10]
● bias: Float [1:1]
Vous voyez immédiatement le nombre total de paramètres dans le module ainsi que leurs noms et formes.
Cela fonctionne également pour les modules personnalisés (y compris éventuellement des sous-modules). Par exemple:
my_module <- nn_module(
initialize = function() {
self$linear <- nn_linear(10, 1)
self$param <- nn_parameter(torch_randn(5,1))
self$buff <- nn_buffer(torch_randn(5))
}
)
my_module()
An `nn_module` containing 16 parameters.
── Modules ─────────────────────────────────────────────────────────────────────
● linear: <nn_linear> #11 parameters
── Parameters ──────────────────────────────────────────────────────────────────
● param: Float [1:5, 1:1]
── Buffers ─────────────────────────────────────────────────────────────────────
● buff: Float [1:5]
Nous espérons que cela facilitera la compréhension nn_module
objets. Nous avons également amélioré la prise en charge de la saisie semi-automatique pour nn_modules
et nous allons maintenant afficher tous les sous-modules, paramètres et tampons pendant que vous tapez.
torcheaudio
torchaudio
est une extension pour torch
développé par Athos Damiani (@athospd
), fournissant un chargement audio, des transformations, des architectures communes pour le traitement du signal, des poids pré-entraînés et un accès aux ensembles de données couramment utilisés. Une traduction presque littérale de la bibliothèque Torchaudio de PyTorch vers R.
torchaudio
n’est pas encore sur CRAN, mais vous pouvez déjà essayer la version de développement disponible ici.
Vous pouvez également visiter le pkgdown
site web pour des exemples et de la documentation de référence.
Autres fonctionnalités et corrections de bugs
Grâce aux contributions de la communauté, nous avons trouvé et corrigé de nombreux bugs dans torch
. Nous avons également ajouté de nouvelles fonctionnalités, notamment :
Vous pouvez voir la liste complète des changements dans le NOUVELLES.md déposer.
Merci beaucoup d’avoir lu cet article de blog et n’hésitez pas à nous contacter sur GitHub pour obtenir de l’aide ou des discussions !
La photo utilisée dans cet aperçu de l’article est de Oleg Illarionov sur Unsplash