Рассмотрим следующие операции:

  1. обучение модели;
  2. сохранение обученной модели в базе данных;
  3. извлечение модели из базы для работы в R.

Первый пункт никак не связан с базами данных и выполняется средствами caret.

Обучение модели

В качестве примера рассмотрим обучение модели классификации с помощью линейного дискриминантного анализа на популярном наборе данных iris. Это может быть любой другой метод классификации или набор данных -- нам сейчас важен принцип.

library(caret)

# Загружаем данные.

data <- iris

# Разделяем данные на обучающую и тестовую выборки.

set.seed(1234)

trainIndex <- createDataPartition(1:nrow(data), times = 1, p = .8)[знаем](1]]

trainSet  <- data[trainIndex,]
testSet   <- data[-trainIndex,]

# Управление обучением

fitCtrl <- trainControl(method = "repeatedcv",  # Кросс-валидация
                        number = 10,            # данные разбиваются на 10 частей
                        repeats = 5             # число повторений
)

# Обучение классификации по методу линейного дискриминантного анализа

model <- train(trainSet[,-5],trainSet[,5],
               method = 'lda',
               trControl = fitCtrl
)

# Обученная модель, которую нужно сохранить:

fit <- model$finalModel

Обученная модель fit представляет собой список, а как сохранять в базе данных списки R мы уже знаем.

Сохранение модели в базе данных

Как и в случае простого списка нам нужно сериализовать модель, а затем преобразовать её в единый блок символов:

# Преобразуем модель в набор символов

fit_char <- rawToChar(serialize(fit, NULL, TRUE))
nchar(fit_char) # сколько там символов?

Вот в таком виде (fit_char) мы можем сохранять модель в базе данных.

SQLite

library(DBI)

# Создаём базу данных.

db <- dbConnect(RSQLite::SQLite(), dbname="models.sqlite")

# Создаём таблицу с атрибутами:
#   'id',
#   'model' - собственно модель, сохраняемая как 'VARCHAR(2000)'.
dbGetQuery(db, 'CREATE TABLE IF NOT EXISTS models
           (id INT PRIMARY KEY,
           model VARCHAR(2000))'
          )

# Создаем data.frame для вставки в БД.

df <- data.frame(id = 1, mdl = fit_char)

# Вставляем данные в таблицу БД.

dbGetPreparedQuery(db, 'INSERT INTO models (model) values (:mdl)',
                   bind.data = df)

PostgreSQL

Здесь таблица будет иметь больше колонок и мы предполагаем, что она уже создана -- нужно её только заполнить.

library(DBI)

# Подключаемся к базе данных.

db <- dbConnect((RPostgreSQL::PostgreSQL(), user="xxxxxx", password="yyyyyy", host="localhost", dbname="test")

# Выводим список существующих таблиц.

dbListTables(db)

# Проверяем, существует ли таблица?
# c("lea","modls") - с(схема,таблица)

dbExistsTable(db, c("lea","modls"))

# Вот как она устроена:
# CREATE TABLE lea.modls
# (
#   modid character(15) NOT NULL, -- Идентификатор модели
#   dscra character(20),          -- Краткое описание модели
#   model text
#   CONSTRAINT modls_pkey PRIMARY KEY (modid),
# )

# Создаем data.frame для вставки в БД.

df <- data.frame(id = "iris_lda", dscr="LDA", mdl = fit_char)

# Вставляем данные в таблицу БД.

dbWriteTable(db, c("lea","modls"), df, row.names=FALSE, append=TRUE)

Импорт модели и работа с ней

Теперь извлечём модель из базы, разметим с её помощью тестовую выборку и оценим достигнутое качество классификации.

Самая важная операция на этом этапе -- восстановление модели в виде объекта R (model2).

SQLite

# Извлекаем данные из таблицы БД.

df2 <- dbGetQuery(db, "SELECT * FROM models")

# Восстанавливаем представление модели в R.

model2 <- unserialize(charToRaw(df2$model))

# Выполняем классификацию с помощью восстановленной модели.

prediction <- predict(model2, newdata = testSet[,-5])

# Проверяем качество классификации.

confusionMatrix(prediction$class, testSet[,5])

# Разрываем соединение.

dbDisconnect(db)

# Удаляем таблицы перед следующим тестовым запуском.
# dbGetQuery(db, "DROP TABLE IF EXISTS models")

PostgreSQL

# Извлекаем данные из таблицы БД.
# Всю таблицу:
# df2 <- dbGetQuery(db, "SELECT * FROM lea.modls")
# Модель с заданным id:
df2 <- dbGetQuery(db, "SELECT * FROM lea.modls WHERE modid='iris_lda'")

# 
# Дальнейшие шаги полностью повторяют то, что мы делали в SQLite.
#


Комментарии

comments powered by Disqus