Skip to content

Commit

Permalink
Add torch parser
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Oct 12, 2024
1 parent f87fe93 commit 841c0df
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 144 deletions.
64 changes: 57 additions & 7 deletions vgslify/parsers/base_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@ class BaseModelParser(ABC):
Provides common utility methods for parsing different frameworks and generating VGSL spec strings.
"""

@abstractmethod
def parse_model(self, model) -> str:
    """Convert ``model`` into a VGSL spec string (implemented by subclasses)."""

def generate_vgsl(self, configs: List[Union[
Conv2DConfig,
Pooling2DConfig,
Expand Down Expand Up @@ -89,6 +84,61 @@ def generate_vgsl(self, configs: List[Union[
# Reverse to restore the original order
return " ".join(vgsl_parts[::-1])

@abstractmethod
def parse_model(self, model) -> str:
    """Convert ``model`` into a VGSL spec string (implemented by subclasses)."""

@abstractmethod
def parse_input(self, layer) -> InputConfig:
    """Convert an input layer into an InputConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_conv2d(self, layer) -> Conv2DConfig:
    """Convert a Conv2D layer into a Conv2DConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_dense(self, layer) -> DenseConfig:
    """Convert a Dense layer into a DenseConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_rnn(self, layer) -> RNNConfig:
    """Convert an RNN layer into an RNNConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_pooling(self, layer) -> Pooling2DConfig:
    """Convert a Pooling layer into a Pooling2DConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_batchnorm(self, layer) -> str:
    """Convert a BatchNorm layer into a VGSL spec string (implemented by subclasses)."""

@abstractmethod
def parse_dropout(self, layer) -> DropoutConfig:
    """Convert a Dropout layer into a DropoutConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_flatten(self, layer) -> str:
    """Convert a Flatten layer into a VGSL spec string (implemented by subclasses)."""

@abstractmethod
def parse_reshape(self, layer) -> ReshapeConfig:
    """Convert a Reshape layer into a ReshapeConfig dataclass (implemented by subclasses)."""

@abstractmethod
def parse_activation(self, layer) -> ActivationConfig:
    """Convert an Activation layer into an ActivationConfig dataclass (implemented by subclasses)."""

# VGSL Generation Methods
def _vgsl_input(self, config: InputConfig) -> str:
return ",".join(map(str, filter(lambda x: x != -1, [
Expand Down Expand Up @@ -156,9 +206,9 @@ def _vgsl_activation(self, config: ActivationConfig) -> str:
def _get_activation_code(self, activation: str) -> str:
ACTIVATION_MAP = {
'softmax': 's', 'tanh': 't', 'relu': 'r',
'linear': 'l', 'sigmoid': 'm'
'linear': 'l', 'sigmoid': 'm', 'identity': 'l'
}
act_code = ACTIVATION_MAP.get(activation.lower(), None)
if act_code is None:
raise ValueError(f"Unsupported activation '{activation}'.")
return act_code
return act_code
166 changes: 71 additions & 95 deletions vgslify/parsers/tf_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import tensorflow as tf
# Imports

# > Standard Library
from typing import Callable, Dict, Type, Union

# > Third-Party Dependencies
import tensorflow as tf

# > Internal
from vgslify.core.config import (
ActivationConfig,
Conv2DConfig,
Expand All @@ -20,14 +27,14 @@ class TensorFlowModelParser(BaseModelParser):
def __init__(self):
# Initialize the layer parsers mapping
self.layer_parsers: Dict[Type[tf.keras.layers.Layer], Callable] = {
tf.keras.layers.InputLayer: self.parse_input_layer,
tf.keras.layers.InputLayer: self.parse_input,
tf.keras.layers.Conv2D: self.parse_conv2d,
tf.keras.layers.Dense: self.parse_dense,
tf.keras.layers.LSTM: lambda layer: self.parse_rnn(layer, "lstm"),
tf.keras.layers.GRU: lambda layer: self.parse_rnn(layer, "gru"),
tf.keras.layers.Bidirectional: self.parse_bidirectional,
tf.keras.layers.MaxPooling2D: lambda layer: self.parse_pooling(layer, "max"),
tf.keras.layers.AveragePooling2D: lambda layer: self.parse_pooling(layer, "average"),
tf.keras.layers.LSTM: self.parse_rnn,
tf.keras.layers.GRU: self.parse_rnn,
tf.keras.layers.Bidirectional: self.parse_rnn,
tf.keras.layers.MaxPooling2D: self.parse_pooling,
tf.keras.layers.AveragePooling2D: self.parse_pooling,
tf.keras.layers.BatchNormalization: self.parse_batchnorm,
tf.keras.layers.Dropout: self.parse_dropout,
tf.keras.layers.Reshape: self.parse_reshape,
Expand Down Expand Up @@ -62,13 +69,13 @@ def parse_model(self, model: tf.keras.models.Model) -> str:
input_shape=model.input_shape[1:],
batch_size=model.input_shape[0]
)
input_config = self.parse_input_layer(input_layer)
input_config = self.parse_input(input_layer)
configs.append(input_config)

# Iterate through all layers in the model
for idx, layer in enumerate(model.layers):
layer_type = type(layer)
parser_func = self._get_parser(layer_type)
parser_func = self.layer_parsers.get(layer_type, None)

if parser_func:
# Parse the layer
Expand All @@ -85,47 +92,10 @@ def parse_model(self, model: tf.keras.models.Model) -> str:
# Generate VGSL spec string from configs
return self.generate_vgsl(configs)

def _get_parser(self, layer_type: Type[tf.keras.layers.Layer]) -> Callable:
"""
Retrieve the parser function for a given layer type.
Parameters
----------
layer_type : Type[tf.keras.layers.Layer]
The type of the layer.
Returns
-------
Callable
The corresponding parser function.
"""
return self.layer_parsers.get(layer_type, None)

def _extract_activation(self, layer: tf.keras.layers.Layer) -> str:
"""
Extract the activation function from a TensorFlow Keras layer.
Parameters
----------
layer : tf.keras.layers.Layer
The layer from which to extract the activation.
Returns
-------
str
The activation function name.
"""
if hasattr(layer, 'activation') and callable(layer.activation):
activation = layer.activation.__name__
elif isinstance(layer, tf.keras.layers.Activation):
activation = layer.activation.__name__
else:
activation = 'linear'
return activation

# Parser methods for different layer types

def parse_input_layer(self, layer: tf.keras.layers.InputLayer) -> InputConfig:
def parse_input(self, layer: tf.keras.layers.InputLayer) -> InputConfig:
"""
Parse an InputLayer into an InputConfig dataclass.
Expand Down Expand Up @@ -199,62 +169,44 @@ def parse_dense(self, layer: tf.keras.layers.Dense) -> DenseConfig:
units=layer.units
)

def parse_rnn(self, layer: Union[tf.keras.layers.LSTM,
                                 tf.keras.layers.GRU,
                                 tf.keras.layers.Bidirectional]) -> RNNConfig:
    """
    Parse an RNN layer (LSTM, GRU, or Bidirectional) into an RNNConfig dataclass.

    Parameters
    ----------
    layer : Union[tf.keras.layers.LSTM, tf.keras.layers.GRU, tf.keras.layers.Bidirectional]
        The RNN layer to parse. For a Bidirectional wrapper, the settings are
        read from its ``forward_layer``.

    Returns
    -------
    RNNConfig
        The configuration for the RNN layer.

    Raises
    ------
    ValueError
        If the (wrapped) layer is neither an LSTM nor a GRU.
    """
    if isinstance(layer, tf.keras.layers.Bidirectional):
        wrapped_layer = layer.forward_layer
        bidirectional = True
    else:
        wrapped_layer = layer
        bidirectional = False

    if isinstance(wrapped_layer, tf.keras.layers.LSTM):
        rnn_type = 'lstm'
    elif isinstance(wrapped_layer, tf.keras.layers.GRU):
        rnn_type = 'gru'
    else:
        raise ValueError(f"Unsupported RNN layer type {type(wrapped_layer).__name__}.")

    return RNNConfig(
        units=wrapped_layer.units,
        return_sequences=wrapped_layer.return_sequences,
        # A Bidirectional wrapper covers both directions, so the flag is forced off.
        go_backwards=wrapped_layer.go_backwards if not bidirectional else False,
        dropout=wrapped_layer.dropout,
        recurrent_dropout=wrapped_layer.recurrent_dropout,
        rnn_type=rnn_type,
        bidirectional=bidirectional
    )

def parse_pooling(self, layer: Union[tf.keras.layers.MaxPooling2D, tf.keras.layers.AveragePooling2D], pool_type: str) -> Pooling2DConfig:
Expand Down Expand Up @@ -314,6 +266,23 @@ def parse_dropout(self, layer: tf.keras.layers.Dropout) -> DropoutConfig:
rate=layer.rate
)

def parse_flatten(self, layer: tf.keras.layers.Flatten) -> str:
    """
    Parse a Flatten layer.

    The layer argument is not inspected; the result is always the fixed
    spec string 'Flt'.

    Parameters
    ----------
    layer : tf.keras.layers.Flatten
        The Flatten layer to parse.

    Returns
    -------
    str
        The spec string 'Flt'.
    """
    return "Flt"

def parse_reshape(self, layer: tf.keras.layers.Reshape) -> ReshapeConfig:
"""
Parse a Reshape layer into a ReshapeConfig dataclass.
Expand All @@ -333,36 +302,43 @@ def parse_reshape(self, layer: tf.keras.layers.Reshape) -> ReshapeConfig:
target_shape=target_shape
)

def parse_activation(self, layer: tf.keras.layers.Activation) -> ActivationConfig:
    """
    Parse an Activation layer.

    Parameters
    ----------
    layer : tf.keras.layers.Activation
        The Activation layer to parse.

    Returns
    -------
    ActivationConfig
        The configuration for the Activation layer, holding the activation
        name obtained from ``_extract_activation``.
    """
    activation = self._extract_activation(layer)
    return ActivationConfig(activation=activation)


# Helper methods
def _extract_activation(self, layer: tf.keras.layers.Layer) -> str:
    """
    Extract the activation function name from a TensorFlow Keras layer.

    Parameters
    ----------
    layer : tf.keras.layers.Layer
        The layer from which to extract the activation.

    Returns
    -------
    str
        The activation function name, or 'linear' when the layer has no
        callable ``activation`` attribute.
    """
    activation_fn = getattr(layer, 'activation', None)
    if not callable(activation_fn):
        return 'linear'

    # Plain functions expose __name__; callable objects may not, so fall back
    # to the class name instead of raising AttributeError.
    name = getattr(activation_fn, '__name__', None)
    if name is None:
        name = type(activation_fn).__name__
    return name
Loading

0 comments on commit 841c0df

Please sign in to comment.