RETURNN example config
See a full demo here: https://github.com/rwth-i6/returnn_common/blob/main/demos/nn-model.returnn-config.py
Put this into some config file, e.g. returnn.config.py:
#!returnn.py
import sys
import os
from typing import Any, Dict
from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
from returnn.util.basic import get_login_username  # used for the model path below
sys.path.insert(0, ...)  # make sure returnn_common can be imported...
demo_name, _ = os.path.splitext(__file__)
print("Hello, experiment: %s" % demo_name)
use_tensorflow = True
task = "train"
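# Task12AXDataset is a synthetic toy dataset (the "12AX" task) generated on the fly by RETURNN.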
train = {"class": "Task12AXDataset", "num_seqs": 1000}
dev = {"class": "Task12AXDataset", "num_seqs": 100, "fixed_random_seed": 1}
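# Dim tags describing the data; used in extern_data below and in the model definition.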
time_dim = SpatialDim("time")
feature_dim = FeatureDim("input", 9)
classes_dim = FeatureDim("classes", 2)
default_input = "data"
target = "classes"
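# extern_data declares the shape of each data stream coming from the dataset.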
extern_data = {
"data": {"dim_tags": [batch_dim, time_dim, feature_dim]},
"classes": {"dim_tags": [batch_dim, time_dim], "sparse_dim": classes_dim},
}
# model / network
def get_network(*, epoch: int, **_kwargs_unused) -> Dict[str, Any]:
    """Called from RETURNN to construct the network for the given epoch."""
    epoch  # noqa  # unused
    from returnn_common import nn
    nn.reset_default_root_name_ctx()  # start with a fresh name context for this network
    data = nn.Data(name=default_input, **extern_data[default_input])
    targets = nn.Data(name=target, **extern_data[target])
    data = nn.get_extern_data(data)
    targets = nn.get_extern_data(targets)

    # We define a simple LSTM network.
    # This is similar to the pure-RETURNN demo-tf-native-lstm.12ax.config.
    class Model(nn.Module):
        """LSTM"""

        def __init__(self):
            super().__init__()
            hidden_dim = nn.FeatureDim("hidden", 10)
            self.lstm = nn.LSTM(feature_dim, hidden_dim)
            self.projection = nn.Linear(hidden_dim, classes_dim)

        def __call__(self, x: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
            x = nn.dropout(x, dropout=0.1, axis=feature_dim)
            x, _ = self.lstm(x, spatial_dim=spatial_dim)
            x = self.projection(x)
            return x  # logits

    net = Model()
    logits = net(data, spatial_dim=time_dim)
    # Framewise cross-entropy loss on the class logits.
    loss = nn.sparse_softmax_cross_entropy_with_logits(logits=logits, targets=targets, axis=classes_dim)
    loss.mark_as_loss("ce")

    # Convert the module into the raw RETURNN network dict.
    net_dict = nn.get_returnn_config().get_net_dict_raw_dict(root_module=net)
    return net_dict
# batching
batching = "random"
batch_size = 5000
max_seqs = 10
chunking = "200:200"  # chunk size 200, chunk step 200
# training
optimizer = {"class": "adam"}
learning_rate = 0.01
model = "/tmp/%s/returnn/%s/model" % (get_login_username(), demo_name)  # filename prefix for model checkpoints
num_epochs = 5
# log
log_verbosity = 3
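You can then run training on this config with RETURNN as usual, e.g. `python3 returnn/rnn.py returnn.config.py` from a RETURNN checkout.

To sanity-check the network construction without starting training, you can also execute the config as a plain Python file and call `get_network` yourself. This is a minimal sketch, assuming the config above is saved as `returnn.config.py` in the current directory, that `returnn` (with TensorFlow) and `returnn_common` are importable, and that the `...` in `sys.path.insert` has been replaced by the real path:

```python
# Minimal sketch: build and inspect the generated net dict outside of RETURNN training.
import runpy
import pprint

# Execute the config as a plain Python file; returns its global variables as a dict.
config_vars = runpy.run_path("returnn.config.py")

# Call the network-construction function just like RETURNN would.
net_dict = config_vars["get_network"](epoch=1)

# The result is the raw RETURNN network dict built via returnn_common.
pprint.pprint(net_dict)
```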