1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
# Loss and optimizer shared by train_step / valid_step.
# NOTE(review): SparseCategoricalCrossentropy() defaults to from_logits=False,
# so the model is presumably expected to emit probabilities (e.g. a final
# softmax layer) and integer class ids as labels — confirm against the model.
loss_func = losses.SparseCategoricalCrossentropy()
optimizer = optimizers.Nadam()

# Streaming metrics: each update_state() call folds one batch into a running
# value, read with result() and cleared explicitly between epochs.
train_loss, train_metric = (
    metrics.Mean(name='train_loss'),
    metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)
valid_loss, valid_metric = (
    metrics.Mean(name='valid_loss'),
    metrics.SparseCategoricalAccuracy(name='valid_accuracy'),
)
@tf.function
def train_step(model, features, labels):
    """Apply one optimizer update on a single batch and record training metrics.

    Args:
        model: Keras model being trained; it is called with ``training=True``.
        features: Input batch passed straight to ``model``.
        labels: Target batch for ``loss_func`` / ``train_metric``
            (presumably sparse integer class ids — confirm).
    """
    # Forward pass under the tape so the loss can be differentiated.
    with tf.GradientTape() as tape:
        preds = model(features, training=True)
        batch_loss = loss_func(labels, preds)

    # NOTE(review): model.losses (e.g. kernel-regularization penalties) are not
    # added to batch_loss, so any such terms do not affect the update — confirm
    # this is intended if the model declares regularizers.
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Fold this batch into the running epoch-level training metrics.
    train_loss.update_state(batch_loss)
    train_metric.update_state(labels, preds)
@tf.function
def valid_step(model, features, labels):
    """Evaluate ``model`` on one batch and record validation metrics.

    No gradients are taken and no weights are modified.

    Args:
        model: Keras model under evaluation.
        features: Input batch passed to ``model``.
        labels: Target batch for ``loss_func`` / ``valid_metric``
            (presumably sparse integer class ids — confirm).
    """
    # Pin inference mode explicitly so layers such as Dropout and
    # BatchNormalization never run in training mode here, regardless of how
    # the implicit learning phase would otherwise be resolved.
    predictions = model(features, training=False)
    batch_loss = loss_func(labels, predictions)

    # Fold this batch into the running epoch-level validation metrics.
    valid_loss.update_state(batch_loss)
    valid_metric.update_state(labels, predictions)
def _reset_metrics(*metric_objs):
    """Reset each Keras metric, working on both Keras 2 and Keras 3."""
    for metric in metric_objs:
        # reset_state() exists since TF 2.5; reset_states() is the legacy name
        # that Keras 3 removed, so prefer the former and fall back to the latter.
        reset = getattr(metric, 'reset_state', None)
        if reset is None:
            reset = metric.reset_states
        reset()


def train_model(model, ds_train, ds_valid, epochs, log_every=1):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Each epoch runs ``train_step`` over every batch of ``ds_train``, then
    ``valid_step`` over every batch of ``ds_valid``. Every ``log_every``
    epochs it prints a separator (via ``printbar``) and the epoch's loss and
    accuracy metrics. All four metrics are then reset, so each logged line
    reflects only the epoch on which it is printed.

    Args:
        model: Keras model to train.
        ds_train: Iterable of ``(features, labels)`` training batches.
        ds_valid: Iterable of ``(features, labels)`` validation batches.
        epochs: Number of epochs to run.
        log_every: Print metrics on epochs divisible by this value
            (default 1 = every epoch, the original behavior).

    Raises:
        ValueError: If ``log_every`` is less than 1.
    """
    if log_every < 1:
        # Guards the epoch % log_every check below against a zero/negative modulus.
        raise ValueError(f"log_every must be >= 1, got {log_every!r}")

    # tf.strings.format fills each "{}" placeholder from the tuple below.
    logs = 'Epoch={},Loss:{},Accuracy:{},Valid Loss:{},Valid Accuracy:{}'

    for epoch in tf.range(1, epochs + 1):
        for features, labels in ds_train:
            train_step(model, features, labels)

        for features, labels in ds_valid:
            valid_step(model, features, labels)

        if epoch % log_every == 0:
            printbar()
            tf.print(tf.strings.format(
                logs,
                (epoch, train_loss.result(), train_metric.result(),
                 valid_loss.result(), valid_metric.result())))
            tf.print("")

        # Clear the running metrics so the next epoch starts from scratch.
        _reset_metrics(train_loss, valid_loss, train_metric, valid_metric)
# Run 10 epochs, using ds_test as the per-epoch validation set.
# NOTE(review): if ds_test is meant as a held-out test split, monitoring it
# every epoch means it is no longer held out — confirm this is intended.
train_model(model, ds_train, ds_test, epochs=10)
|