Вопрос:

Нейронная сеть повторного использования Tensorflow

python tensorflow

364 просмотра

1 ответ

11 Репутация автора

Я новичок в тензорном потоке, и я тренировал простую нейронную сеть, но после обучения я не знаю, как повторно использовать NN для получения выходных данных.

def train_neural_network(x,y,aDataTrain,aTargetTrain,aDataTest,aTargetTest):
    batch_size = 500
    prediction = neural_network_model(x,len(aDataTrain[0]))
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)
    hm_epochs = 1

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(hm_epochs):
            epoch_loss = 0
            i = 0
            while i < len(aDataTrain):
                start = i
                end = i + batch_size
                batch_x = np.array(aDataTrain[start:end])
                batch_y = np.array(aTargetTrain[start:end])

                _,c = sess.run([optimizer,cost],feed_dict={x:batch_x,y:batch_y})
                epoch_loss += c
                i += batch_size
            print ("Epoch", epoch, "completed out of", hm_epochs, "loss", epoch_loss)

        correct =tf.equal(tf.argmax(prediction,1), tf.argmax(y,1))

        accurracy = tf.reduce_mean(tf.cast(correct,'float'))
        finalAcc = accurracy.eval({x:aDataTest,y:aTargetTest})
        saver.save(sess, 'model/model.ckpt')

    print("Accuracy:",finalAcc)

Итак, после того, как я сохранил модель и попытался ее восстановить, я не знаю, как продолжить получать выходные данные NN из «input_data».

def execute_neural_network(x,y,aDataTrain,aTargetTrain,aDataTest,aTargetTest):
    batch_size = 1
    y_pred = []

    prediction = neural_network_model(x,len(aDataTrain[0]))
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
    optimizer = tf.train.AdamOptimizer().minimize(cost)

    input_data = [5.0, 3.0, 1.0, 5.0, 6.0, 5.0, 2.0, 4.0, 7.0, 6.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, 3.0, 2.0, 3.0, 2.0, 3.0, 3.0, 4.0, 3.0, 3.0, 2.0, 4.0, 3.0, 3.0, 2.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 61.0, 21.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 75.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 35.0, 11.0, 10.0, 33.0, 24.0, 6.0, 2.0, 2.0, 3.0, 4.0, 3.0, 3.0, 8.0, 6.0, 5.0, 6.0, 5.0, 8.0, 9.0, 13.0, 7.0, 25.0, 11.0, 2.0, 2.0, 2.0, 2.0, 2.0]

    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, 'model/model.ckpt')
        #Get neural network output from input_data
Автор: Joan Guillem Castell Ros-zanet Источник Размещён: 08.11.2017 10:09

Ответы (1)


0 плюса

3206 Репутация автора

Предполагая, что вы создаете модель графа / сети примерно так:

with tf.Session() as sess:
    #do other stuff
    predictionOp = tf.argmax(py_x, 1)

saver.save(sess, 'model') 

где predictionOpпеременная, которая является выходом вашей сети.

Впоследствии вы можете добавить что-то вроде этого: tf.add_to_collection("predictionOp", predictionOp) дать predictionOpимя, чтобы его было легче найти. Затем вы можете перезагрузить модель и получить прогнозы:

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')
    new_saver.restore(sess, 'model')
    predictionOp = tf.get_collection("predictionOp")[0]

    #get the prediction
    prediction = sess.run(predictionOp, feed_dict={"x:0": input_data})

Для получения дополнительной информации, пожалуйста, ознакомьтесь с tensorflow документацией и здесь для получения дополнительной информации об основах. Кроме того, есть некоторые другие темы, которые имеют дело с подобными проблемами, как эта и эта .

Автор: FlashTek Размещён: 08.11.2017 10:33
Вопросы из категории :
32x32