Вопрос:

Как загрузить обученную модель MXnet?

deep-learning mxnet

2605 просмотра

1 ответ

126 Репутация автора

Я обучил сеть с использованием MXnet, но не уверен, как сохранить и загрузить параметры для последующего использования. Сначала я определяю и обучаю сеть:

    dataIn = mx.sym.var('data')
    fc1 = mx.symbol.FullyConnected(data=dataIn, num_hidden=100)
    act1 = mx.sym.Activation(data=fc1, act_type="relu")
    fc2 = mx.symbol.FullyConnected(data=act1, num_hidden=50)
    act2 = mx.sym.Activation(data=fc2, act_type="relu")
    fc3 = mx.symbol.FullyConnected(data=act2, num_hidden=25)
    act3 = mx.sym.Activation(data=fc3, act_type="relu")
    fc4 = mx.symbol.FullyConnected(data=act3, num_hidden=10)
    act4 = mx.sym.Activation(data=fc4, act_type="relu")
    fc5 = mx.symbol.FullyConnected(data=act4, num_hidden=2)
    lenet = mx.sym.SoftmaxOutput(data=fc5, name='softmax',normalization = 'batch')


# create iterator around training and validation data
train_iter = mx.io.NDArrayIter(data=data[:ntrain], label = phen[:ntrain],batch_size=batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data=data[ntrain:], label=phen[ntrain:], batch_size=batch_size)

# create a trainable module on GPU 0
lenet_model = mx.mod.Module(symbol=lenet, context=mx.gpu())
# train with the same
lenet_model.fit(train_iter,
                eval_data=val_iter,
                optimizer='adam',
                optimizer_params={'learning_rate':0.00001},
                eval_metric='f1',
                batch_end_callback = mx.callback.Speedometer(batch_size, 10),
                num_epoch=1000)

Эта модель хорошо работает на тестовом наборе, поэтому я хочу сохранить его. Далее я сохраняю макет сети и параметризацию:

lenet.save('./testNet_symbol.mxnet')
lenet_model.save_params('./testNet_module.mxnet')

Кажется, что вся документация, которую я могу найти при загрузке сети, реализовала функцию сохранения в обучающей программе, чтобы сохранять параметры сети в конце каждой эпохи. Я не устанавливал эти контрольные точки во время процесса обучения. Другие методы используют класс mx.model.FeedForward, который не кажется подходящим. Еще другие методы загружают сеть из файла .json, которого у меня нет в результате моих функций сохранения. Как я могу сохранить / загрузить сеть после того, как она уже закончила обучение?

Автор: Nuclear Wang Источник Размещён: 08.11.2017 10:14

Ответы (1)


3 плюса

31 Репутация автора

Вы просто должны сделать это вместо того, чтобы сохранить:

lenet_model.save_checkpoint('lenet', num_epoch, save_optimizer_states=True)

Это создаст 3 файла, если флаг состояния установлен в True, иначе 2 файла:

.params (веса), .json (символ), .states

И это загрузить:

lenet_model = mx.mod.Module.load(prefix,epoch)
lenet_model.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
Автор: Mohammad Размещён: 14.11.2017 07:43
Вопросы из категории :
32x32