import os

# Silence TensorFlow's C++ info/warning logs (only errors printed).
# This must be set BEFORE tensorflow is imported, or it has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics

def preprocess(x, y):
    """Convert one (image, label) pair for training.

    Scales the uint8 image tensor ``x`` into float32 values in [0, 1]
    and casts the label tensor ``y`` to int32 (as expected by
    ``tf.one_hot`` / ``tf.equal`` downstream).

    Returns:
        Tuple ``(x, y)`` of the converted tensors.
    """
    x = tf.cast(x, dtype=tf.float32) / 255
    y = tf.cast(y, dtype=tf.int32)
    return x, y

# Load Fashion-MNIST: 60k training images/labels and a 10k test split.
(x,y),(x_test,y_test)=datasets.fashion_mnist.load_data()
print(x.shape,y.shape)

batchsz=100
# Training pipeline: normalize via preprocess, shuffle, then batch.
db=tf.data.Dataset.from_tensor_slices((x,y))
db=db.map(preprocess).shuffle(10000).batch(batchsz)

# Test pipeline (no shuffling needed for evaluation).
# NOTE(review): "db_tset" looks like a typo for "db_test"; kept as-is
# because main() below references this exact name.
db_tset=tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_tset=db_tset.map(preprocess).batch(batchsz)

# Peek at one batch to sanity-check the batched shapes.
db_iter = iter(db)
sample=next(db_iter)
print("batch:",sample[0].shape,sample[1].shape)

model = Sequential([ # Sequential container holding a list of layers
layers.Dense(256,activation=tf.nn.relu), # params: 784*256 + 256 = 200960
layers.Dense(128,activation=tf.nn.relu), # params: 256*128 + 128 = 32896
layers.Dense(64,activation=tf.nn.relu), # params: 128*64 + 64 = 8256
layers.Dense(32,activation=tf.nn.relu), # params: 64*32 + 32 = 2080
layers.Dense(10) # params: 32*10 + 10 = 330 (raw logits; no activation)
]) # total 244522 parameters; at 4 bytes each that is roughly 1 MB of weights
# model.build(input_shape=[None,28*28]) # give the network its input shape up front
# model.summary() # print the layer-by-layer structure

optimizer=optimizers.Adam(learning_rate=1e-3) # optimizer performing w = w - lr * grad

def main():
    """Train the MLP on Fashion-MNIST for 50 epochs.

    Each epoch runs one pass over the training set ``db`` (cross-entropy
    loss drives the gradient step; MSE is computed only for logging),
    then evaluates classification accuracy on the test set ``db_tset``.
    Relies on the module-level ``db``, ``db_tset``, ``model`` and
    ``optimizer`` defined above.
    """
    for epoch in range(50):
        for step, (x, y) in enumerate(db):
            # Flatten images: [b, 28, 28] => [b, 784]
            x = tf.reshape(x, [-1, 28*28])

            with tf.GradientTape() as tape:
                logits = model(x)  # [b, 784] => [b, 10]
                y_onehot = tf.one_hot(y, depth=10)
                # MSE is logged for comparison only; training uses cross-entropy.
                loss_mse = tf.reduce_mean(tf.losses.MSE(y_onehot, logits))
                # from_logits=True: softmax is applied inside the loss (numerically stable).
                loss_ce = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss_ce = tf.reduce_mean(loss_ce)

            grads = tape.gradient(loss_ce, model.trainable_variables)
            # Pair each gradient with its variable and update in place.
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            if step % 100 == 0:
                print(epoch, step, "loss:", float(loss_ce), float(loss_mse))

        # Evaluate accuracy over the full test set.
        total_correct = 0
        total_num = 0
        for x, y in db_tset:
            x = tf.reshape(x, [-1, 28*28])
            logits = model(x)
            # logits => prob, [b, 10]: softmax maps reals to probabilities summing to 1.
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            # pred: [b], y: [b]; correct: [b] of True (equal) / False (not equal)
            correct = tf.equal(y, pred)
            correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))
            total_correct += int(correct)
            total_num += x.shape[0]  # accumulate the number of samples seen

        acc = total_correct / total_num
        print(epoch, "test acc:", acc)


if __name__ == '__main__':
    main()