Accueil Intelligence artificielle Posit AI Blog : Entraînez-vous dans R, exécutez sur Android : segmentation d’images avec la torche

Posit AI Blog : Entraînez-vous dans R, exécutez sur Android : segmentation d’images avec la torche

0
Posit AI Blog : Entraînez-vous dans R, exécutez sur Android : segmentation d’images avec la torche


Dans un sens, la segmentation d’images n’est pas si différente de la classification d’images. C’est juste qu’au lieu de catégoriser une image dans son ensemble, la segmentation aboutit à une étiquette pour chaque pixel. Et comme dans la classification d’images, les catégories d’intérêt dépendent de la tâche : premier plan versus arrière-plan, par exemple ; différents types de tissus ; différents types de végétation ; etc.

Le présent article n’est pas le premier sur ce blog à traiter ce sujet ; et comme tous les précédents, il utilise une architecture U-Net pour atteindre son objectif. Les caractéristiques centrales (de ce poste, pas de U-Net) sont :

  1. Il montre comment effectuer une augmentation des données pour une tâche de segmentation d’image.

  2. Il utilise lumière, torchL’interface de haut niveau de, pour entraîner le modèle.

  3. Il Traces JIT le modèle formé et l’enregistre pour le déploiement sur les appareils mobiles. (JIT étant l’acronyme couramment utilisé pour le torch compilateur juste à temps.)

  4. Il comprend un code de preuve de concept (mais pas une discussion) du modèle enregistré exécuté sur Android.

Et si vous pensez que cela n’est pas assez excitant en soi, notre tâche ici est de trouver des chats et des chiens. Quoi de plus utile qu’une application mobile vous permettant de distinguer votre chat du canapé moelleux sur lequel il repose ?

Un chat de l'Oxford Pet Dataset (Parkhi et al. (2012)).

S’entraîner en R

Nous commençons par préparer les données.

Prétraitement et augmentation des données

Tel que fourni par torchdatasetsle Ensemble de données sur les animaux de compagnie d’Oxford est livré avec trois variantes de données cibles parmi lesquelles choisir : la classe globale (chat ou chien), la race individuelle (il y en a trente-sept) et une segmentation au niveau des pixels avec trois catégories : premier plan, limite et arrière-plan. Ce dernier est la valeur par défaut ; et c’est exactement le type de cible dont nous avons besoin.

Un appel à oxford_pet_dataset(root = dir) déclenchera le téléchargement initial :

# need torch > 0.6.1
# may have to run remotes::install_github("mlverse/torch", ref = remotes::github_pull("713")) depending on when you read this
library(torch) 
library(torchvision)
library(torchdatasets)
library(luz)

dir <- "~/.torch-datasets/oxford_pet_dataset"

ds <- oxford_pet_dataset(root = dir)

Les images (et les masques correspondants) sont disponibles en différentes tailles. Cependant, pour l’entraînement, nous aurons besoin qu’ils soient tous de la même taille. Cela peut être accompli en passant transform = et target_transform = arguments. Mais qu’en est-il de l’augmentation des données (en fait, c’est toujours une mesure utile à prendre) ? Imaginez que nous utilisions le retournement aléatoire. Une image d’entrée sera inversée – ou non – selon une certaine probabilité. Mais si l’image est inversée, le masque aurait intérêt à l’être aussi ! Dans ce cas, les transformations d’entrée et de cible ne sont pas indépendantes.

Une solution consiste à créer un wrapper autour oxford_pet_dataset() qui nous permet de « nous accrocher » au .getitem() méthode, comme ceci :

pet_dataset <- torch::dataset(
  
  inherit = oxford_pet_dataset,
  
  initialize = function(..., size, normalize = TRUE, augmentation = NULL) {
    
    self$augmentation <- augmentation
    
    input_transform <- function(x) {
      x <- x %>%
        transform_to_tensor() %>%
        transform_resize(size) 
      # we'll make use of pre-trained MobileNet v2 as a feature extractor
      # => normalize in order to match the distribution of images it was trained with
      if (isTRUE(normalize)) x <- x %>%
        transform_normalize(mean = c(0.485, 0.456, 0.406),
                            std = c(0.229, 0.224, 0.225))
      x
    }
    
    target_transform <- function(x) {
      x <- torch_tensor(x, dtype = torch_long())
      x <- x[newaxis,..]
      # interpolation = 0 makes sure we still end up with integer classes
      x <- transform_resize(x, size, interpolation = 0)
    }
    
    super$initialize(
      ...,
      transform = input_transform,
      target_transform = target_transform
    )
    
  },
  .getitem = function(i) {
    
    item <- super$.getitem(i)
    if (!is.null(self$augmentation)) 
      self$augmentation(item)
    else
      list(x = item$x, y = item$y[1,..])
  }
)

Tout ce que nous avons à faire maintenant est de créer une fonction personnalisée qui nous permet de décider quelle augmentation appliquer à chaque paire entrée-cible, puis d’appeler manuellement les fonctions de transformation respectives.

Ici, nous retournons en moyenne une image sur deux, et si nous le faisons, nous retournons également le masque. La deuxième transformation – orchestrant des changements aléatoires de luminosité, de saturation et de contraste – est appliquée uniquement à l’image d’entrée.

augmentation <- function(item) {
  
  vflip <- runif(1) > 0.5
  
  x <- item$x
  y <- item$y
  
  if (isTRUE(vflip)) {
    x <- transform_vflip(x)
    y <- transform_vflip(y)
  }
  
  x <- transform_color_jitter(x, brightness = 0.5, saturation = 0.3, contrast = 0.3)
  
  list(x = x, y = y[1,..])
  
}

Nous utilisons maintenant le wrapper, pet_dataset()pour instancier les ensembles de formation et de validation, et créer les chargeurs de données respectifs.

train_ds <- pet_dataset(root = dir,
                        split = "train",
                        size = c(224, 224),
                        augmentation = augmentation)
valid_ds <- pet_dataset(root = dir,
                        split = "valid",
                        size = c(224, 224))

train_dl <- dataloader(train_ds, batch_size = 32, shuffle = TRUE)
valid_dl <- dataloader(valid_ds, batch_size = 32)

Définition du modèle

Le modèle implémente une architecture U-Net classique, avec une étape de codage (la passe « descendante »), une étape de décodage (la passe « montante ») et, surtout, un « pont » qui transmet les fonctionnalités préservées de l’étape de codage vers couches correspondantes dans l’étape de décodage.

Encodeur

Tout d’abord, nous avons l’encodeur. Il utilise un modèle pré-entraîné (MobileNet v2) comme extracteur de fonctionnalités.

L’encodeur divise les blocs d’extraction de fonctionnalités de MobileNet v2 en plusieurs étapes et applique une étape après l’autre. Les résultats respectifs sont enregistrés dans une liste.

encoder <- nn_module(
  
  initialize = function() {
    model <- model_mobilenet_v2(pretrained = TRUE)
    self$stages <- nn_module_list(list(
      nn_identity(),
      model$features[1:2],
      model$features[3:4],
      model$features[5:7],
      model$features[8:14],
      model$features[15:18]
    ))

    for (par in self$parameters) {
      par$requires_grad_(FALSE)
    }

  },
  forward = function(x) {
    features <- list()
    for (i in 1:length(self$stages)) {
      x <- self$stages[[i]](x)
      features[[length(features) + 1]] <- x
    }
    features
  }
)

Décodeur

Le décodeur est composé de blocs configurables. Un bloc reçoit deux tenseurs d’entrée : l’un qui est le résultat de l’application du bloc de décodeur précédent et l’autre qui contient la carte de caractéristiques produite lors de l’étape d’encodeur correspondant. Lors de la passe directe, le premier est d’abord suréchantillonné et passé par une non-linéarité. Le résultat intermédiaire est ensuite ajouté au deuxième argument, la carte des caractéristiques canalisées. Sur le tenseur résultant, une convolution est appliquée, suivie d’une autre non-linéarité.

decoder_block <- nn_module(
  
  initialize = function(in_channels, skip_channels, out_channels) {
    self$upsample <- nn_conv_transpose2d(
      in_channels = in_channels,
      out_channels = out_channels,
      kernel_size = 2,
      stride = 2
    )
    self$activation <- nn_relu()
    self$conv <- nn_conv2d(
      in_channels = out_channels + skip_channels,
      out_channels = out_channels,
      kernel_size = 3,
      padding = "same"
    )
  },
  forward = function(x, skip) {
    x <- x %>%
      self$upsample() %>%
      self$activation()

    input <- torch_cat(list(x, skip), dim = 2)

    input %>%
      self$conv() %>%
      self$activation()
  }
)

Le décodeur lui-même instancie et parcourt les blocs :

decoder <- nn_module(
  
  initialize = function(
    decoder_channels = c(256, 128, 64, 32, 16),
    encoder_channels = c(16, 24, 32, 96, 320)
  ) {

    encoder_channels <- rev(encoder_channels)
    skip_channels <- c(encoder_channels[-1], 3)
    in_channels <- c(encoder_channels[1], decoder_channels)

    depth <- length(encoder_channels)

    self$blocks <- nn_module_list()
    for (i in seq_len(depth)) {
      self$blocks$append(decoder_block(
        in_channels = in_channels[i],
        skip_channels = skip_channels[i],
        out_channels = decoder_channels[i]
      ))
    }

  },
  forward = function(features) {
    features <- rev(features)
    x <- features[[1]]
    for (i in seq_along(self$blocks)) {
      x <- self$blocks[[i]](x, features[[i+1]])
    }
    x
  }
)

Module de niveau supérieur

Enfin, le module de niveau supérieur génère le score de la classe. Dans notre tâche, il existe trois classes de pixels. Le sous-module de production de partition peut alors n’être qu’une convolution finale, produisant trois canaux :

model <- nn_module(
  
  initialize = function() {
    self$encoder <- encoder()
    self$decoder <- decoder()
    self$output <- nn_sequential(
      nn_conv2d(in_channels = 16,
                out_channels = 3,
                kernel_size = 3,
                padding = "same")
    )
  },
  forward = function(x) {
    x %>%
      self$encoder() %>%
      self$decoder() %>%
      self$output()
  }
)

Formation sur modèle et évaluation (visuelle)

Avec luzla formation du modèle est une affaire de deux verbes, setup() et fit(). Le taux d’apprentissage a été déterminé, pour ce cas précis, à l’aide de luz::lr_finder(); vous devrez probablement le modifier lorsque vous expérimenterez différentes formes d’augmentation des données (et différents ensembles de données).

model <- model %>%
  setup(optimizer = optim_adam, loss = nn_cross_entropy_loss())

fitted <- model %>%
  set_opt_hparams(lr = 1e-3) %>%
  fit(train_dl, epochs = 10, valid_data = valid_dl)

Voici un extrait de l’évolution des performances d’entraînement dans mon cas :

# Epoch 1/10
# Train metrics: Loss: 0.504                                                           
# Valid metrics: Loss: 0.3154

# Epoch 2/10
# Train metrics: Loss: 0.2845                                                           
# Valid metrics: Loss: 0.2549

...
...

# Epoch 9/10
# Train metrics: Loss: 0.1368                                                           
# Valid metrics: Loss: 0.2332

# Epoch 10/10
# Train metrics: Loss: 0.1299                                                           
# Valid metrics: Loss: 0.2511

Les chiffres ne sont que des chiffres : dans quelle mesure le modèle formé est-il vraiment efficace pour segmenter les images d’animaux de compagnie ? Pour le savoir, nous générons des masques de segmentation pour les huit premières observations de l’ensemble de validation et les traçons en superposition sur les images. Un moyen pratique de tracer une image et de superposer un masque est fourni par le raster emballer.

Les intensités des pixels doivent être comprises entre zéro et un, c’est pourquoi, dans le wrapper de l’ensemble de données, nous avons fait en sorte que la normalisation puisse être désactivée. Pour tracer les images réelles, nous instancions simplement un clone de valid_ds cela laisse les valeurs des pixels inchangées. (Les prédictions, en revanche, devront toujours être obtenues à partir de l’ensemble de validation d’origine.)

valid_ds_4plot <- pet_dataset(
  root = dir,
  split = "valid",
  size = c(224, 224),
  normalize = FALSE
)

Enfin, les prédictions sont générées en boucle et superposées aux images une par une :

indices <- 1:8

preds <- predict(fitted, dataloader(dataset_subset(valid_ds, indices)))

png("pet_segmentation.png", width = 1200, height = 600, bg = "black")

par(mfcol = c(2, 4), mar = rep(2, 4))

for (i in indices) {
  
  mask <- as.array(torch_argmax(preds[i,..], 1)$to(device = "cpu"))
  mask <- raster::ratify(raster::raster(mask))
  
  img <- as.array(valid_ds_4plot[i][[1]]$permute(c(2,3,1)))
  cond <- img > 0.99999
  img[cond] <- 0.99999
  img <- raster::brick(img)
  
  # plot image
  raster::plotRGB(img, scale = 1, asp = 1, margins = TRUE)
  # overlay mask
  plot(mask, alpha = 0.4, legend = FALSE, axes = FALSE, add = TRUE)
  
}
Masques de segmentation appris, superposés aux images de l'ensemble de validation.

Passons maintenant à l’exécution de ce modèle « à l’état sauvage » (enfin, en quelque sorte).

Tracer JIT et exécuter sur Android

Le traçage du modèle entraîné le convertira en un formulaire pouvant être chargé dans des environnements sans R, par exemple à partir de Python, C++ ou Java.

Nous accédons au torch modèle sous-jacent à l’ajustement luz objet, et tracez-le – où tracer signifie l’appeler une fois, sur un échantillon d’observation :

m <- fitted$model
x <- coro::collect(train_dl, 1)

traced <- jit_trace(m, x[[1]]$x)

Le modèle tracé peut désormais être enregistré pour être utilisé avec Python ou C++, comme suit :

traced %>% jit_save("traced_model.pt")

Cependant, comme nous savons déjà que nous aimerions le déployer sur Android, nous utilisons plutôt la fonction spécialisée jit_save_for_mobile() qui génère en outre du bytecode :

# need torch > 0.6.1
jit_save_for_mobile(traced_model, "model_bytecode.pt")

Et c’est tout pour la face R !

Pour fonctionner sur Android, j’ai beaucoup utilisé Android de PyTorch Mobile exemples d’applicationsen particulier le segmentation d’images un.

Le code de validation de principe de cet article (qui a été utilisé pour générer l’image ci-dessous) peut être trouvé ici : https://github.com/skeydan/ImageSegmentation. (Attention cependant : c’est ma première application Android !).

Bien sûr, nous devons encore essayer de retrouver le chat. Voici le modèle, exécuté sur un émulateur d’appareil dans Android Studio, sur trois images (de l’Oxford Pet Dataset) sélectionnées, d’une part, pour leur large gamme de difficulté, et d’autre part, eh bien… pour leur gentillesse :

Où est mon chat ?

Merci d’avoir lu!

Parkhi, Omkar M., Andrea Vedaldi, Andrew Zisserman et CV Jawahar. 2012. « Les chats et les chiens. » Dans Conférence IEEE sur la vision par ordinateur et la reconnaissance de formes.

Ronneberger, Olaf, Philipp Fischer et Thomas Brox. 2015. «U-Net : réseaux convolutifs pour la segmentation d’images biomédicales. » CoRR abs/1505.04597. http://arxiv.org/abs/1505.04597.

LAISSER UN COMMENTAIRE

S'il vous plaît entrez votre commentaire!
S'il vous plaît entrez votre nom ici