mardi, novembre 28, 2023

Blog Posit AI : Implémentation de l’équivariance de rotation : CNN équivariant de groupe à partir de zéro


Les réseaux de neurones convolutifs (CNN) sont formidables : ils sont capables de détecter les caractéristiques d’une image, où qu’ils soient. Eh bien, pas exactement. Ils ne sont pas indifférents à n’importe quel type de mouvement. Se déplacer vers le haut ou le bas, ou vers la gauche ou la droite, est très bien ; tourner autour d’un axe ne l’est pas. Cela est dû au fonctionnement de la convolution : parcours par ligne, puis parcours par colonne (ou inversement). Si nous voulons « plus » (par exemple, la détection réussie d’un objet à l’envers), nous devons étendre la convolution à une opération qui est équivariant de rotation. Une opération qui est équivariant à un certain type d’action enregistrera non seulement l’entité déplacée en soi, mais gardera également une trace de l’action concrète qui l’a fait apparaître là où elle se trouve.

Ceci est le deuxième article d’une série qui présente les CNN équivariants de groupe (GCNN). Le d’abord était une introduction de haut niveau expliquant pourquoi nous les voudrions et comment ils fonctionnent. Nous y avons présenté l’acteur clé, le groupe de symétrie, qui spécifie quels types de transformations doivent être traités de manière équivariante. Si ce n’est pas le cas, jetez d’abord un œil à cet article, car j’utiliserai ici la terminologie et les concepts qu’il a introduits.

Aujourd’hui, nous codons un simple GCNN à partir de zéro. Le code et la présentation suivent étroitement un carnet de notes dispensé dans le cadre du programme 2022 de l’Université d’Amsterdam Cours d’apprentissage profond. On ne saurait assez les remercier d’avoir mis à disposition d’aussi excellents matériels d’apprentissage.

Dans ce qui suit, mon intention est d’expliquer la pensée générale et la manière dont l’architecture résultante est construite à partir de modules plus petits, chacun ayant un objectif clair. Pour cette raison, je ne reproduirai pas tout le code ici ; à la place, j’utiliserai le package gcnn. Ses méthodes sont fortement annotées ; alors pour voir quelques détails, n’hésitez pas à regarder le code.

À ce jour, gcnn implémente un groupe de symétrie : \(C_4\), celui qui sert d’exemple courant tout au long du premier article. Il est cependant directement extensible, utilisant des hiérarchies de classes partout.

Étape 1 : Le groupe de symétrie \(C_4\)

Lors du codage d’un GCNN, la première chose que nous devons fournir est une implémentation du groupe de symétrie que nous aimerions utiliser. C’est ici \(C_4\)le groupe de quatre éléments qui tourne de 90 degrés.

Nous pouvons demander gcnn pour en créer un pour nous et inspecter ses éléments.

# remotes::install_github("skeydan/gcnn")
library(gcnn)
library(torch)

C_4 <- CyclicGroup(order = 4)
elems <- C_4$elements()
elems
torch_tensor
 0.0000
 1.5708
 3.1416
 4.7124
[ CPUFloatType{4} ]

Les éléments sont représentés par leurs angles de rotation respectifs : \(0\), \(\frac{\pi}{2}\), \(\pi\)et \(\frac{3 \pi}{2}\).

Les groupes sont conscients de l’identité et savent construire l’inverse d’un élément :

C_4$identity

g1 <- elems[2]
C_4$inverse(g1)
torch_tensor
 0
[ CPUFloatType{1} ]

torch_tensor
4.71239
[ CPUFloatType{} ]

Ici, ce qui nous intéresse le plus, ce sont les éléments du groupe. action. Du point de vue de la mise en œuvre, il faut distinguer entre leurs actions les unes sur les autres et leur action sur l’espace vectoriel. \(\mathbb{R}^2\), où vivent nos images d’entrée. La première partie est la plus simple : elle peut simplement être mise en œuvre en ajoutant des angles. En fait, c’est ce que gcnn fait quand on lui demande de laisser g1 agir sur g2:

g2 <- elems[3]

# in C_4$left_action_on_H(), H stands for the symmetry group
C_4$left_action_on_H(torch_tensor(g1)$unsqueeze(1), torch_tensor(g2)$unsqueeze(1))
torch_tensor
 4.7124
[ CPUFloatType{1,1} ]

C’est quoi ce unsqueeze()s? Depuis \(C_4\)c’est l’ultime raison d’être c’est faire partie d’un réseau de neurones, left_action_on_H() fonctionne avec des lots d’éléments, pas avec des tenseurs scalaires.

Les choses sont un peu moins simples lorsque l’action de groupe sur \(\mathbb{R}^2\) est concerné. Ici, nous avons besoin du concept de représentation de groupe. C’est un sujet complexe que nous n’aborderons pas ici. Dans notre contexte actuel, cela fonctionne comme ceci : nous avons un signal d’entrée, un tenseur sur lequel nous aimerions opérer d’une manière ou d’une autre. (Cette « d’une manière ou d’une autre » sera une convolution, comme nous le verrons bientôt.) Pour rendre cette opération équivariante au groupe, nous demandons d’abord à la représentation d’appliquer le inverse action de groupe à l’entrée. Cela fait, nous continuons l’opération comme si de rien n’était.

Pour donner un exemple concret, disons que l’opération est une mesure. Imaginez un coureur, debout au pied d’un sentier de montagne, prêt à gravir l’ascension. Nous aimerions enregistrer leur taille. Une option que nous avons est de prendre la mesure, puis de la laisser courir. Nos mesures seront aussi valables en haut de la montagne qu’elles l’étaient ici. Alternativement, nous pourrions être polis et ne pas les faire attendre. Une fois qu’ils sont là-haut, nous leur demandons de redescendre, et lorsqu’ils reviennent, nous mesurons leur taille. Le résultat est le même : la taille du corps est équivariante (plus que cela : invariante, même) à l’action de monter ou de descendre. (Bien sûr, la taille est une mesure assez ennuyeuse. Mais quelque chose de plus intéressant, comme la fréquence cardiaque, n’aurait pas aussi bien fonctionné dans cet exemple.)

Pour en revenir à l’implémentation, il s’avère que les actions de groupe sont codées sous forme de matrices. Il existe une matrice pour chaque élément du groupe. Pour \(C_4\)la dite standard la représentation est une matrice de rotation :

\[
\begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix}
\]

Dans gcnnla fonction appliquant cette matrice est left_action_on_R2(). Comme son frère, il est conçu pour fonctionner avec des lots (d’éléments de groupe ainsi que \(\mathbb{R}^2\) vecteurs). Techniquement, il fait pivoter la grille sur laquelle l’image est définie, puis ré-échantillonne l’image. Pour rendre cela plus concret, le code de cette méthode se présente comme suit.

Voici une chèvre.

img_path <- system.file("imgs", "z.jpg", package = "gcnn")
img <- torchvision::base_loader(img_path) |> torchvision::transform_to_tensor()
img$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot()

Une chèvre assise confortablement sur un pré.

Tout d’abord, nous appelons C_4$left_action_on_R2() pour faire pivoter la grille.

# Grid shape is [2, 1024, 1024], for a 2d, 1024 x 1024 image.
img_grid_R2 <- torch::torch_stack(torch::torch_meshgrid(
    list(
      torch::torch_linspace(-1, 1, dim(img)[2]),
      torch::torch_linspace(-1, 1, dim(img)[3])
    )
))

# Transform the image grid with the matrix representation of some group element.
transformed_grid <- C_4$left_action_on_R2(C_4$inverse(g1)$unsqueeze(1), img_grid_R2)

Deuxièmement, nous ré-échantillonnons l’image sur la grille transformée. La chèvre lève maintenant les yeux vers le ciel.

transformed_img <- torch::nnf_grid_sample(
  img$unsqueeze(1), transformed_grid,
  align_corners = TRUE, mode = "bilinear", padding_mode = "zeros"
)

transformed_img[1,..]$permute(c(2, 3, 1)) |> as.array() |> as.raster() |> plot()

Même chèvre, tournée de 90 degrés vers le haut.

Étape 2 : La convolution de levage

Nous souhaitons utiliser les outils existants et efficaces torch fonctionnalité autant que possible. Concrètement, nous voulons utiliser nn_conv2d(). Ce dont nous avons besoin, cependant, c’est d’un noyau de convolution qui soit équivariant non seulement à la traduction, mais aussi à l’action de \(C_4\). Ceci peut être réalisé en ayant un noyau pour chaque rotation possible.

Mettre en œuvre cette idée est exactement ce que LiftingConvolution fait. Le principe est le même que précédemment : tout d’abord, la grille est pivotée, puis le noyau (matrice de poids) est rééchantillonné dans la grille transformée.

Pourquoi, cependant, appeler cela un convolution de levage? Le noyau de convolution habituel fonctionne sur \(\mathbb{R}^2\); tandis que notre version étendue fonctionne sur des combinaisons de \(\mathbb{R}^2\) et \(C_4\). En mathématiques, cela a été levé au produit semi-direct \(\mathbb{R}^2\rtimes C_4\).

lifting_conv <- LiftingConvolution(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 3,
    out_channels = 8
  )

x <- torch::torch_randn(c(2, 3, 32, 32))
y <- lifting_conv(x)
y$shape
[1]  2  8  4 28 28

Puisque, intérieurement, LiftingConvolution utilise une dimension supplémentaire pour réaliser le produit des traductions et des rotations, le résultat n’est pas à quatre, mais à cinq dimensions.

Étape 3 : Convolutions de groupe

Maintenant que nous sommes dans un « espace étendu de groupe », nous pouvons enchaîner un certain nombre de couches où les entrées et les sorties sont convolution de groupe couches. Par exemple:

group_conv <- GroupConvolution(
  group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 8,
    out_channels = 16
)

z <- group_conv(y)
z$shape
[1]  2 16  4 24 24

Il ne reste plus qu’à emballer tout cela. C’est ce que gcnn::GroupEquivariantCNN() fait.

Étape 4 : CNN équivalent à un groupe

Nous pouvons appeler GroupEquivariantCNN() ainsi.

cnn <- GroupEquivariantCNN(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 1,
    num_hidden = 2, # number of group convolutions
    hidden_channels = 16 # number of channels per group conv layer
)

img <- torch::torch_randn(c(4, 1, 32, 32))
cnn(img)$shape
[1] 4 1

D’un simple coup d’oeil, ceci GroupEquivariantCNN on dirait n’importe quel vieux CNN… n’était-ce pas le group argument.

Désormais, lorsque nous inspectons sa sortie, nous constatons que la dimension supplémentaire a disparu. En effet, après une séquence de couches de convolution de groupe à groupe, le module se projette vers une représentation qui, pour chaque élément du lot, conserve uniquement les canaux. La moyenne n’est donc pas seulement basée sur les emplacements – comme nous le faisons habituellement – ​​mais également sur la dimension du groupe. Une couche linéaire finale fournira alors la sortie du classificateur demandée (de dimension out_channels).

Et là nous avons l’architecture complète. Il est temps de passer à un monde réel (ouais) test.

Chiffres pivotés !

L’idée est de former deux convnets, un CNN « normal » et un équivalent de groupe, sur l’ensemble de formation habituel du MNIST. Ensuite, les deux sont évalués sur un ensemble de tests augmentés où chaque image subit une rotation aléatoire selon une rotation continue entre 0 et 360 degrés. Nous ne nous attendons pas GroupEquivariantCNN être « parfait » – pas si nous nous équipons de \(C_4\) comme groupe de symétrie. Strictement, avec \(C_4\), l’équivariance s’étend sur quatre positions seulement. Mais nous espérons qu’elle fonctionnera nettement mieux que l’architecture standard à équivariant uniquement.

Tout d’abord, nous préparons les données ; en particulier, l’ensemble de tests augmenté.

dir <- "/tmp/mnist"

train_ds <- torchvision::mnist_dataset(
  dir,
  download = TRUE,
  transform = torchvision::transform_to_tensor
)

test_ds <- torchvision::mnist_dataset(
  dir,
  train = FALSE,
  transform = function(x) {
    x |>
      torchvision::transform_to_tensor() |>
      torchvision::transform_random_rotation(
        degrees = c(0, 360),
        resample = 2,
        fill = 0
      )
  }
)

train_dl <- dataloader(train_ds, batch_size = 128, shuffle = TRUE)
test_dl <- dataloader(test_ds, batch_size = 128)

De quoi ça a l’air?

test_images <- coro::collect(
  test_dl, 1
)[[1]]$x[1:32, 1, , ] |> as.array()

par(mfrow = c(4, 8), mar = rep(0, 4), mai = rep(0, 4))
test_images |>
  purrr::array_tree(1) |>
  purrr::map(as.raster) |>
  purrr::iwalk(~ {
    plot(.x)
  })

32 chiffres, tournés de manière aléatoire.

Nous définissons et formons d’abord un CNN conventionnel. C’est aussi semblable à GroupEquivariantCNN()du point de vue de l’architecture, autant que possible, et dispose de deux fois plus de canaux cachés, afin d’avoir une capacité globale comparable.

 default_cnn <- nn_module(
   "default_cnn",
   initialize = function(kernel_size, in_channels, out_channels, num_hidden, hidden_channels) {
     self$conv1 <- torch::nn_conv2d(in_channels, hidden_channels, kernel_size)
     self$convs <- torch::nn_module_list()
     for (i in 1:num_hidden) {
       self$convs$append(torch::nn_conv2d(hidden_channels, hidden_channels, kernel_size))
     }
     self$avg_pool <- torch::nn_adaptive_avg_pool2d(1)
     self$final_linear <- torch::nn_linear(hidden_channels, out_channels)
   },
   forward = function(x) {
     x <- x |>
       self$conv1() |>
       (\(.) torch::nnf_layer_norm(., .$shape[2:4]))() |>
       torch::nnf_relu()
     for (i in 1:(length(self$convs))) {
       x <- x |>
         self$convs[[i]]() |>
         (\(.) torch::nnf_layer_norm(., .$shape[2:4]))() |>
         torch::nnf_relu()
     }
     x <- x |>
       self$avg_pool() |>
       torch::torch_squeeze() |>
       self$final_linear()
     x
   }
 )

fitted <- default_cnn |>
    luz::setup(
      loss = torch::nn_cross_entropy_loss(),
      optimizer = torch::optim_adam,
      metrics = list(
        luz::luz_metric_accuracy()
      )
    ) |>
    luz::set_hparams(
      kernel_size = 5,
      in_channels = 1,
      out_channels = 10,
      num_hidden = 4,
      hidden_channels = 32
    ) %>%
    luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
    luz::fit(train_dl, epochs = 10, valid_data = test_dl) 
Train metrics: Loss: 0.0498 - Acc: 0.9843
Valid metrics: Loss: 3.2445 - Acc: 0.4479

Sans surprise, la précision sur l’ensemble de test n’est pas très bonne.

Ensuite, nous formons la version équivariante de groupe.

fitted <- GroupEquivariantCNN |>
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz::luz_metric_accuracy()
    )
  ) |>
  luz::set_hparams(
    group = CyclicGroup(order = 4),
    kernel_size = 5,
    in_channels = 1,
    out_channels = 10,
    num_hidden = 4,
    hidden_channels = 16
  ) |>
  luz::set_opt_hparams(lr = 1e-2, weight_decay = 1e-4) |>
  luz::fit(train_dl, epochs = 10, valid_data = test_dl)
Train metrics: Loss: 0.1102 - Acc: 0.9667
Valid metrics: Loss: 0.4969 - Acc: 0.8549

Pour CNN équivariant au groupe, les précisions sur les ensembles de tests et de formation sont beaucoup plus proches. C’est un joli résultat ! Terminons l’exploit d’aujourd’hui en reprenant une réflexion du premier message de plus haut niveau.

Un défi

En revenant à l’ensemble de tests augmenté, ou plutôt aux échantillons de chiffres affichés, nous remarquons un problème. Dans la deuxième ligne, colonne quatre, il y a un chiffre qui, « dans des circonstances normales », devrait être un 9, mais, très probablement, il s’agit d’un 6 à l’envers. (Pour un humain, ce qui suggère que c’est la chose semblable à un gribouillis qui semble se trouver plus souvent avec les six qu’avec les neuf.) Cependant, vous pourriez vous demander : est-ce que cela avoir être un problème ? Peut-être que le réseau a simplement besoin d’apprendre les subtilités, le genre de choses qu’un humain repérerait ?

Selon moi, tout dépend du contexte : ce qui doit réellement être accompli et comment une application va être utilisée. Avec des chiffres sur une lettre, je ne vois aucune raison pour qu’un seul chiffre apparaisse à l’envers ; par conséquent, l’équivariance de rotation complète serait contre-productive. En un mot, nous arrivons au même impératif canonique que les partisans d’un apprentissage automatique juste et équitable ne cessent de nous rappeler :

Pensez toujours à la manière dont une application va être utilisée !

Mais dans notre cas, il y a un autre aspect à cela, d’ordre technique. gcnn::GroupEquivariantCNN() est un simple wrapper, dans la mesure où ses couches utilisent toutes le même groupe de symétrie. En principe, cela n’est pas nécessaire. Avec plus d’effort de codage, différents groupes peuvent être utilisés en fonction de la position d’une couche dans la hiérarchie de détection de fonctionnalités.

Ici, laissez-moi enfin vous dire pourquoi j’ai choisi la photo de la chèvre. La chèvre est vue à travers une clôture rouge et blanche, un motif – légèrement tourné, en raison de l’angle de vue – composé de carrés (ou de bords, si vous préférez). Or, pour une telle clôture, les types d’équivariance de rotation tels que celui codé par \(C_4\) ont beaucoup de sens. La chèvre elle-même, cependant, nous préférerions ne pas regarder vers le ciel, comme je l’ai illustré \(C_4\) action avant. Ainsi, dans une tâche de classification d’images réelle, nous utiliserions des couches plutôt flexibles en bas et des couches de plus en plus restreintes en haut de la hiérarchie.

Merci d’avoir lu!

photo par Marjan Blan | @marjanblan sur Unsplash

Related Articles

LAISSER UN COMMENTAIRE

S'il vous plaît entrez votre commentaire!
S'il vous plaît entrez votre nom ici

Latest Articles