import tensorflow as tf
import random
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("tmp/data/", one_hot=True)
tf.set_random_seed(777)

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.random_normal([784, 10], stddev=0.1))
b = tf.Variable(tf.random_normal([10], stddev=0.1))

# Keep the raw logits separate: softmax_cross_entropy_with_logits expects
# unnormalized logits, not the softmax output (passing the softmaxed
# hypothesis in would apply softmax twice and slow learning).
logits = tf.matmul(x, w) + b
hypothesis = tf.nn.softmax(logits)

cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y))
optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(cost)

# tf.arg_max is deprecated; use tf.argmax
is_correct = tf.equal(tf.argmax(hypothesis, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

t_epoch = 10
batch_size = 100

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    total_batch = int(mnist.train.num_examples / batch_size)

    for epoch in range(t_epoch):
        avg_cost = 0
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            c, _ = sess.run([cost, optimizer], feed_dict={x: batch_x, y: batch_y})
            avg_cost += c
        print('Epoch {:2d} cost = {:.9f}'.format(epoch + 1, avg_cost / total_batch))

    print("Accuracy: ", accuracy.eval(session=sess,
          feed_dict={x: mnist.test.images, y: mnist.test.labels}) * 100)

    # Predict one random test image
    idx = random.randint(0, mnist.test.num_examples - 1)
    print("Label: ", sess.run(tf.argmax(mnist.test.labels[idx:idx + 1], 1)))
    print("Prediction: ", sess.run(tf.argmax(hypothesis, 1),
          feed_dict={x: mnist.test.images[idx:idx + 1]}))
1. Dataset: MNIST
2. Number of layers: trained with a single layer
3. Cost function: cross-entropy
4. Optimizer: gradient descent
5. Libraries: TensorFlow only
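The two core pieces above, the softmax hypothesis and the cross-entropy cost, can be sketched in plain NumPy to show what the TensorFlow ops compute. This is a minimal illustration with a made-up toy batch, not the training graph itself:

```python
import numpy as np

def softmax(logits):
    # subtract the row-wise max for numerical stability before exponentiating
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, labels):
    # mean over the batch of -sum(y * log(p)); small epsilon avoids log(0)
    return -np.mean(np.sum(labels * np.log(probs + 1e-12), axis=1))

# toy batch: 2 samples, 3 classes, one-hot labels
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.3]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

probs = softmax(logits)
print(cross_entropy(probs, labels))
```

Note this mirrors why the TF1 API separates `logits` from the softmax output: the fused `softmax_cross_entropy_with_logits` op does the max-subtraction trick internally, which is more stable than computing `log(softmax(x))` in two steps.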