Skip to content

RETURNN example config

Albert Zeyer edited this page Oct 25, 2022 · 1 revision

See a full demo here: https://github.com/rwth-i6/returnn_common/blob/main/demos/nn-model.returnn-config.py

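Put the following into a RETURNN config file (e.g. `returnn.config.py`). Replace the `...` placeholder in `sys.path.insert(0, ...)` with the directory that contains the `returnn_common` package: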

#!returnn.py

import os
import sys
from typing import Any, Dict

from returnn.tf.util.data import batch_dim, SpatialDim, FeatureDim
from returnn.util.basic import get_login_username

# Make returnn_common importable. The `...` is a placeholder: put the path of
# the directory that contains the returnn_common package here.
sys.path.insert(0, ...)

# Experiment name: this config's own path without its file extension.
demo_name, _ = os.path.splitext(__file__)
print(f"Hello, experiment: {demo_name}")

use_tensorflow = True

# Train on Task12AXDataset; the dev set is created with a fixed random seed.
task = "train"
train = {"class": "Task12AXDataset", "num_seqs": 1000}
dev = {"class": "Task12AXDataset", "num_seqs": 100, "fixed_random_seed": 1}

# Dimension tags, shared between extern_data and the network in get_network().
time_dim = SpatialDim("time")
feature_dim = FeatureDim("input", 9)
classes_dim = FeatureDim("classes", 2)

# Keys of extern_data used as network input and as training target.
default_input = "data"
target = "classes"

extern_data = {
    default_input: dict(dim_tags=[batch_dim, time_dim, feature_dim]),
    target: dict(dim_tags=[batch_dim, time_dim], sparse_dim=classes_dim),
}


# model / network
def get_network(*, epoch: int, **_kwargs_unused) -> Dict[str, Any]:
    """
    Build the RETURNN net dict using returnn_common.

    Called from the RETURNN config. ``epoch`` is accepted but not used.
    The network is dropout -> LSTM -> linear projection to ``classes_dim``
    logits. A sparse softmax cross-entropy loss named "ce" is computed
    against the ``target`` stream.

    :return: the raw net dict produced by returnn_common
    """
    del epoch  # unused
    from returnn_common import nn

    nn.reset_default_root_name_ctx()

    # Describe both extern data streams first, then fetch them as tensors.
    input_template = nn.Data(name=default_input, **extern_data[default_input])
    target_template = nn.Data(name=target, **extern_data[target])
    source = nn.get_extern_data(input_template)
    labels = nn.get_extern_data(target_template)

    # A simple LSTM network, similar to the pure-RETURNN
    # demo-tf-native-lstm.12ax.config.
    class Model(nn.Module):
        """Dropout, then LSTM, then a linear projection to class logits."""

        def __init__(self):
            super().__init__()
            hidden_dim = nn.FeatureDim("hidden", 10)
            self.lstm = nn.LSTM(feature_dim, hidden_dim)
            self.projection = nn.Linear(hidden_dim, classes_dim)

        def __call__(self, x: nn.Tensor, *, spatial_dim: nn.Dim) -> nn.Tensor:
            """Return unnormalized class logits for ``x``."""
            dropped = nn.dropout(x, dropout=0.1, axis=feature_dim)
            lstm_out, _state = self.lstm(dropped, spatial_dim=spatial_dim)
            return self.projection(lstm_out)

    root_module = Model()
    logits = root_module(source, spatial_dim=time_dim)
    ce_loss = nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, targets=labels, axis=classes_dim
    )
    ce_loss.mark_as_loss("ce")

    return nn.get_returnn_config().get_net_dict_raw_dict(root_module=root_module)


# batching
batching = "random"
batch_size = 5000
max_seqs = 10
chunking = "200:200"  # RETURNN chunking spec ("size:step"); see RETURNN docs

# training
optimizer = {"class": "adam"}
learning_rate = 0.01
# get_login_username is returnn.util.basic.get_login_username, imported at the
# top of this config. Without that import, this line raises NameError when the
# config is loaded.
model = "/tmp/%s/returnn/%s/model" % (get_login_username(), demo_name)
num_epochs = 5

# log
log_verbosity = 3
Clone this wiki locally