Skip to content

Commit

Permalink
tested process workload and internal tool calling
Browse files Browse the repository at this point in the history
  • Loading branch information
wilsonwang371 committed Jan 26, 2025
1 parent 9b0c17f commit f663139
Show file tree
Hide file tree
Showing 15 changed files with 290 additions and 121 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ This command will:
To run SPEAR in local mode, use the following command:

```bash
# If you are using the official OpenAI API, set OPENAI_API_BASE=https://api.openai.com/v1
export OPENAI_API_BASE=<YOUR_OPENAI_API_BASE>
export OPENAI_API_KEY=<YOUR_OPENAI_API_KEY>
bin/spearlet exec -n pyconversation
```
Expand Down
60 changes: 40 additions & 20 deletions cmd/spearlet/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type SpearletConfig struct {
var (
execRtTypeStr string
execWorkloadName string
execProcFileName string
execReqMethod string
execReqPayload string

Expand Down Expand Up @@ -51,11 +52,15 @@ func NewRootCmd() *cobra.Command {
"process": task.TaskTypeProcess,
"dylib": task.TaskTypeDylib,
"wasm": task.TaskTypeWasm,
"unknown": task.TaskTypeUnknown,
}

if execWorkloadName == "" {
if execWorkloadName == "" && execProcFileName == "" {
log.Errorf("Invalid workload name %s", execWorkloadName)
return
} else if execProcFileName != "" && execWorkloadName != "" {
log.Errorf("Cannot specify both workload name and process filename at the same time")
return
}
if execReqMethod == "" {
log.Errorf("Invalid request method %s", execReqMethod)
Expand All @@ -80,32 +85,47 @@ func NewRootCmd() *cobra.Command {
w := spearlet.NewSpearlet(config)
w.Initialize()

// lookup task id
execWorkloadId, err := w.LookupTaskId(execWorkloadName)
if err != nil {
log.Errorf("Error looking up task id: %v", err)
// print available tasks
tasks := w.ListTasks()
log.Infof("Available tasks: %v", tasks)
defer func() {
w.Stop()
return
}

res, err := w.ExecuteTask(execWorkloadId, rtType, true, execReqMethod, execReqPayload)
if err != nil {
log.Errorf("Error executing workload: %v", err)
}()
if execWorkloadName != "" {
// lookup task id
execWorkloadId, err := w.LookupTaskId(execWorkloadName)
if err != nil {
log.Errorf("Error looking up task id: %v", err)
// print available tasks
tasks := w.ListTasks()
log.Infof("Available tasks: %v", tasks)
return
} else {
res, err := w.ExecuteTask(execWorkloadId, rtType, true, execReqMethod, execReqPayload)
if err != nil {
log.Errorf("Error executing workload: %v", err)
return
}
log.Debugf("Workload execution result: %v", res)
}
} else if execProcFileName != "" {
res, err := w.ExecuteTaskNoMeta(execProcFileName, rtType, true, execReqMethod, execReqPayload)
if err != nil {
log.Errorf("Error executing workload: %v", err)
return
}
log.Debugf("Workload execution result: %v", res)
}
log.Debugf("Workload execution result: %v", res)
w.Stop()
// TODO: implement workload execution
}
},
}

// workload id
execCmd.PersistentFlags().StringVarP(&execWorkloadName, "name", "n", "", "workload name")
// workload name
execCmd.PersistentFlags().StringVarP(&execWorkloadName, "name", "n", "",
"workload name. Cannot be used with process workload filename at the same time")
// workload filename
execCmd.PersistentFlags().StringVarP(&execProcFileName, "file", "f", "",
"process workload filename. Only valid for process type workload")
// workload type, a choice of Docker, Process, Dylib or Wasm
execCmd.PersistentFlags().StringVarP(&execRtTypeStr, "type", "t", "Docker", "type of the workload")
execCmd.PersistentFlags().StringVarP(&execRtTypeStr, "type", "t", "unknown",
"type of the workload. By default, it is unknown and the spearlet will try to determine the type.")
// workload request payload
execCmd.PersistentFlags().StringVarP(&execReqMethod, "method", "m", "handle", "default method to invoke")
execCmd.PersistentFlags().StringVarP(&execReqPayload, "payload", "p", "", "request payload")
Expand Down
1 change: 1 addition & 0 deletions proto/io/record-req.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ namespace spear.proto.io;
table RecordRequest {
prompt: string (required);
model: string;
dryrun: bool=false;
}

root_type RecordRequest;
7 changes: 4 additions & 3 deletions sdk/python/spear/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import flatbuffers as fbs

from spear.proto.custom import CustomRequest
from spear.proto.tool import ToolInvocationRequest, InternalToolInfo, ToolInfo, ToolInvocationResponse
from spear.proto.transport import (Method, TransportMessageRaw,
from spear.proto.tool import (InternalToolInfo, ToolInfo,
ToolInvocationRequest, ToolInvocationResponse)
from spear.proto.transport import (Method, Signal, TransportMessageRaw,
TransportMessageRaw_Data, TransportRequest,
TransportResponse, TransportSignal, Signal)
TransportResponse, TransportSignal)

MAX_INFLIGHT_REQUESTS = 128
DEFAULT_MESSAGE_SIZE = 4096
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/spear/transform/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from spear.proto.chat import (ChatCompletionRequest, ChatCompletionResponse,
ChatMessage, ChatMetadata, Role)
from spear.proto.chat import ToolInfo as ChatToolInfo
from spear.proto.tool import BuiltinToolInfo, ToolInfo, InternalToolInfo
from spear.proto.tool import BuiltinToolInfo, InternalToolInfo, ToolInfo
from spear.proto.transform import (TransformOperation, TransformRequest,
TransformRequest_Params, TransformResponse,
TransformResponse_Data, TransformType)
Expand Down
3 changes: 2 additions & 1 deletion sdk/python/spear/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def speak(


def record(agent: client.HostAgent, prompt: str,
model: str = "whisper-1") -> str:
model: str = "whisper-1", dryrun=False) -> str:
"""
get user input
"""
Expand All @@ -83,6 +83,7 @@ def record(agent: client.HostAgent, prompt: str,
RecordRequest.RecordRequestAddPrompt(builder, prompt_off)
if model:
RecordRequest.RecordRequestAddModel(builder, model_off)
RecordRequest.RecordRequestAddDryrun(builder, dryrun)
data_off = RecordRequest.RecordRequestEnd(builder)
builder.Finish(data_off)
res = agent.exec_request(
Expand Down
7 changes: 4 additions & 3 deletions sdk/python/spear/utils/tool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#!/usr/bin/env python3
import logging
import inspect
import logging

import flatbuffers as fbs
import spear.client as client

from spear.proto.tool import (
InternalToolCreateRequest, InternalToolCreateResponse, InternalToolCreateParamSpec)
from spear.proto.tool import (InternalToolCreateParamSpec,
InternalToolCreateRequest,
InternalToolCreateResponse)
from spear.proto.transport import Method

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion spearlet/hostcalls/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ func innerChatCompletion(inv *hcommon.InvocationInfo, chatReq *chat.ChatCompleti
})

toolCalls := choice.Message.ToolCalls
log.Infof("Tool calls: %d", len(toolCalls))
log.Debugf("Tool calls amount: %d", len(toolCalls))
for _, toolCall := range toolCalls {
log.Infof("Tool call: %s", toolCall.Function.Name)
argsStr := toolCall.Function.Arguments
Expand Down
148 changes: 78 additions & 70 deletions spearlet/hostcalls/common/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@ package common

import (
"os"
"strings"

log "github.com/sirupsen/logrus"
)

type OpenAIFunctionType int

type APIEndpointInfo struct {
Name string
Model string
Base *string
APIKey string
Url string
Name string
Model string
Base *string
APIKey string // used if APIKeyInEnv is empty
APIKeyInEnv string // if not empty, the API key is in env var
Url string
}

const (
Expand Down Expand Up @@ -42,18 +42,18 @@ var (
APIEndpointMap = map[OpenAIFunctionType][]APIEndpointInfo{
OpenAIFunctionTypeChatWithTools: {
{
Name: "deepseek-toolchat",
Model: "deepseek-chat",
Base: &DeepSeekBase,
APIKey: os.Getenv("DEEPSEEK_API_KEY"),
Url: "/chat/completions",
Name: "deepseek-toolchat",
Model: "deepseek-chat",
Base: &DeepSeekBase,
APIKeyInEnv: "DEEPSEEK_API_KEY",
Url: "/chat/completions",
},
{
Name: "openai-toolchat",
Model: "gpt-4o",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/chat/completions",
Name: "openai-toolchat",
Model: "gpt-4o",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/chat/completions",
},
{
Name: "qwen-toolchat-72b",
Expand Down Expand Up @@ -93,11 +93,11 @@ var (
},
OpenAIFunctionTypeChatOnly: {
{
Name: "openai-chat",
Model: "gpt-4o",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/chat/completions"},
Name: "openai-chat",
Model: "gpt-4o",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/chat/completions"},
{
Name: "llama-chat",
Model: "llama",
Expand All @@ -108,11 +108,11 @@ var (
},
OpenAIFunctionTypeEmbeddings: {
{
Name: "openai-embed",
Model: "text-embedding-ada-002",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/embeddings",
Name: "openai-embed",
Model: "text-embedding-ada-002",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/embeddings",
},
{
Name: "nomic-embed",
Expand All @@ -124,20 +124,20 @@ var (
},
OpenAIFunctionTypeTextToSpeech: {
{
Name: "openai-tts",
Model: "tts-1",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/audio/speech",
Name: "openai-tts",
Model: "tts-1",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/audio/speech",
},
},
OpenAIFunctionTypeImageGeneration: {
{
Name: "openai-genimage",
Model: "dall-e-3",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/images/generations",
Name: "openai-genimage",
Model: "dall-e-3",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/images/generations",
},
},
OpenAIFunctionTypeSpeechToText: {
Expand All @@ -149,11 +149,11 @@ var (
Url: "/audio/transcriptions",
},
{
Name: "openai-whisper",
Model: "whisper-1",
Base: &OpenAIBase,
APIKey: os.Getenv("OPENAI_API_KEY"),
Url: "/audio/transcriptions",
Name: "openai-whisper",
Model: "whisper-1",
Base: &OpenAIBase,
APIKeyInEnv: "OPENAI_API_KEY",
Url: "/audio/transcriptions",
},
},
}
Expand All @@ -169,40 +169,48 @@ func GetAPIEndpointInfo(ft OpenAIFunctionType, modelOrName string) []APIEndpoint
res = append(res, info)
}
}
tmpList := make([]APIEndpointInfo, 0)

// remove if the api key is from env but not set
res2 := make([]APIEndpointInfo, 0)
for _, e := range res {
tmp := &APIEndpointInfo{
Name: e.Name,
Model: e.Model,
Base: e.Base,
APIKey: "********",
Url: e.Url,
}
if e.APIKey == "" {
tmp.APIKey = ""
if e.APIKeyInEnv != "" {
key := os.Getenv(e.APIKeyInEnv)
if key == "" {
// skip if the key is not set
continue
}
res2 = append(res2, e)
res2[len(res2)-1].APIKey = key
} else {
res2 = append(res2, e)
}
tmpList = append(tmpList, *tmp)
}
log.Infof("Found %d endpoint(s) for %s: %v", len(tmpList), modelOrName, tmpList)
return res
}

func init() {
if os.Getenv("OPENAI_API_KEY") == "" {
log.Warnf("OPENAI_API_KEY not set, disabling openai functions")
newAPIEndpointMap := map[OpenAIFunctionType][]APIEndpointInfo{}
for ft, infoList := range APIEndpointMap {
newInfoList := []APIEndpointInfo{}
for _, info := range infoList {
// copy data only if the name does not contain "openai"
if !strings.Contains(info.Name, "openai") {
newInfoList = append(newInfoList, info)
}
func() {
// print the endpoint info found
tmpList := make([]APIEndpointInfo, 0)
for _, e := range res2 {
tmp := &APIEndpointInfo{
Name: e.Name,
Model: e.Model,
Base: e.Base,
APIKey: "********",
Url: e.Url,
}
newAPIEndpointMap[ft] = newInfoList
if e.APIKey == "" {
tmp.APIKey = ""
}
tmpList = append(tmpList, *tmp)
}
APIEndpointMap = newAPIEndpointMap
} else {
OpenAIBase = "https://api.openai.com/v1"
log.Infof("Found %d endpoint(s) for %s: %v", len(tmpList), modelOrName, tmpList)
}()

return res2
}

// init allows overriding the default OpenAI API base URL via the
// OPENAI_API_BASE environment variable. When the variable is unset,
// the package-level default in OpenAIBase is left untouched.
func init() {
	// Official endpoint is "https://api.openai.com/v1"; a custom base
	// supports OpenAI-compatible proxies and self-hosted gateways.
	if base := os.Getenv("OPENAI_API_BASE"); base != "" {
		OpenAIBase = base
	}
}
Loading

0 comments on commit f663139

Please sign in to comment.