From f663139a42c2c1b782e030d2a9c375c3f25aa723 Mon Sep 17 00:00:00 2001 From: Wilson Wang Date: Sun, 26 Jan 2025 15:40:08 +0800 Subject: [PATCH] tested process workload and internal tool calling --- README.md | 2 + cmd/spearlet/main.go | 60 ++++--- proto/io/record-req.fbs | 1 + sdk/python/spear/client.py | 7 +- sdk/python/spear/transform/chat.py | 2 +- sdk/python/spear/utils/io.py | 3 +- sdk/python/spear/utils/tool.py | 7 +- spearlet/hostcalls/chat.go | 2 +- spearlet/hostcalls/common/models.go | 148 +++++++++--------- spearlet/hostcalls/io.go | 5 + spearlet/hostcalls/openai/openai_hc.go | 6 +- spearlet/spearlet.go | 112 ++++++++++++- spearlet/task/proc_rt.go | 6 +- .../python/pytest-functionality/src/start.py | 37 +++-- .../process/python/pytest-functionality.py | 13 +- 15 files changed, 290 insertions(+), 121 deletions(-) diff --git a/README.md b/README.md index f6bcdab..5b3dc48 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,8 @@ This command will: To run SPEAR in local mode, use the following command: ```bash +# if you are using openai official api, you can set OPENAI_API_BASE=https://api.openai.com/v1 +export OPENAI_API_BASE= export OPENAI_API_KEY= bin/spearlet exec -n pyconversation ``` diff --git a/cmd/spearlet/main.go b/cmd/spearlet/main.go index 40846af..397dc4f 100644 --- a/cmd/spearlet/main.go +++ b/cmd/spearlet/main.go @@ -20,6 +20,7 @@ type SpearletConfig struct { var ( execRtTypeStr string execWorkloadName string + execProcFileName string execReqMethod string execReqPayload string @@ -51,11 +52,15 @@ func NewRootCmd() *cobra.Command { "process": task.TaskTypeProcess, "dylib": task.TaskTypeDylib, "wasm": task.TaskTypeWasm, + "unknown": task.TaskTypeUnknown, } - if execWorkloadName == "" { + if execWorkloadName == "" && execProcFileName == "" { log.Errorf("Invalid workload name %s", execWorkloadName) return + } else if execProcFileName != "" && execWorkloadName != "" { + log.Errorf("Cannot specify both workload name and process filename at the same time") + return } if execReqMethod == "" { log.Errorf("Invalid request method %s", execReqMethod) @@ -80,32 +85,47 @@ func NewRootCmd() *cobra.Command { w := spearlet.NewSpearlet(config) w.Initialize() - // lookup task id - execWorkloadId, err := w.LookupTaskId(execWorkloadName) - if err != nil { - log.Errorf("Error looking up task id: %v", err) - // print available tasks - tasks := w.ListTasks() - log.Infof("Available tasks: %v", tasks) + defer func() { w.Stop() - return - } - - res, err := w.ExecuteTask(execWorkloadId, rtType, true, execReqMethod, execReqPayload) - if err != nil { - log.Errorf("Error executing workload: %v", err) + }() + if execWorkloadName != "" { + // lookup task id + execWorkloadId, err := w.LookupTaskId(execWorkloadName) + if err != nil { + log.Errorf("Error looking up task id: %v", err) + // print available tasks + tasks := w.ListTasks() + log.Infof("Available tasks: %v", tasks) + return + } else { + res, err := w.ExecuteTask(execWorkloadId, rtType, true, execReqMethod, execReqPayload) + if err != nil { + log.Errorf("Error executing workload: %v", err) + return + } + log.Debugf("Workload execution result: %v", res) + } + } else if execProcFileName != "" { + res, err := w.ExecuteTaskNoMeta(execProcFileName, rtType, true, execReqMethod, execReqPayload) + if err != nil { + log.Errorf("Error executing workload: %v", err) + return + } + log.Debugf("Workload execution result: %v", res) } - log.Debugf("Workload execution result: %v", res) - w.Stop() - // TODO: implement workload execution } }, } - // workload id - execCmd.PersistentFlags().StringVarP(&execWorkloadName, "name", "n", "", "workload name") + // workload name + execCmd.PersistentFlags().StringVarP(&execWorkloadName, "name", "n", "", + "workload name. Cannot be used with process workload filename at the same time") + // workload filename + execCmd.PersistentFlags().StringVarP(&execProcFileName, "file", "f", "", + "process workload filename. Only valid for process type workload") // workload type, a choice of Docker, Process, Dylib or Wasm - execCmd.PersistentFlags().StringVarP(&execRtTypeStr, "type", "t", "Docker", "type of the workload") + execCmd.PersistentFlags().StringVarP(&execRtTypeStr, "type", "t", "unknown", + "type of the workload. By default, it is unknown and the spearlet will try to determine the type.") // workload request payload execCmd.PersistentFlags().StringVarP(&execReqMethod, "method", "m", "handle", "default method to invoke") execCmd.PersistentFlags().StringVarP(&execReqPayload, "payload", "p", "", "request payload") diff --git a/proto/io/record-req.fbs b/proto/io/record-req.fbs index 3c79bc8..fb83bcf 100644 --- a/proto/io/record-req.fbs +++ b/proto/io/record-req.fbs @@ -3,6 +3,7 @@ namespace spear.proto.io; table RecordRequest { prompt: string (required); model: string; + dryrun: bool=false; } root_type RecordRequest; \ No newline at end of file diff --git a/sdk/python/spear/client.py b/sdk/python/spear/client.py index 9f32be7..0c19850 100644 --- a/sdk/python/spear/client.py +++ b/sdk/python/spear/client.py @@ -13,10 +13,11 @@ import flatbuffers as fbs from spear.proto.custom import CustomRequest -from spear.proto.tool import ToolInvocationRequest, InternalToolInfo, ToolInfo, ToolInvocationResponse -from spear.proto.transport import (Method, TransportMessageRaw, +from spear.proto.tool import (InternalToolInfo, ToolInfo, + ToolInvocationRequest, ToolInvocationResponse) +from spear.proto.transport import (Method, Signal, TransportMessageRaw, TransportMessageRaw_Data, TransportRequest, - TransportResponse, TransportSignal, Signal) + TransportResponse, TransportSignal) MAX_INFLIGHT_REQUESTS = 128 DEFAULT_MESSAGE_SIZE = 4096 diff --git a/sdk/python/spear/transform/chat.py b/sdk/python/spear/transform/chat.py index a1c15d1..1290890 100644 --- a/sdk/python/spear/transform/chat.py +++ b/sdk/python/spear/transform/chat.py @@ -11,7 +11,7 @@ from spear.proto.chat import (ChatCompletionRequest, ChatCompletionResponse, ChatMessage, ChatMetadata, Role) from spear.proto.chat import ToolInfo as ChatToolInfo -from spear.proto.tool import BuiltinToolInfo, ToolInfo, InternalToolInfo +from spear.proto.tool import BuiltinToolInfo, InternalToolInfo, ToolInfo from spear.proto.transform import (TransformOperation, TransformRequest, TransformRequest_Params, TransformResponse, TransformResponse_Data, TransformType) diff --git a/sdk/python/spear/utils/io.py b/sdk/python/spear/utils/io.py index 03e21d7..2ef6a77 100644 --- a/sdk/python/spear/utils/io.py +++ b/sdk/python/spear/utils/io.py @@ -70,7 +70,7 @@ def speak( def record(agent: client.HostAgent, prompt: str, - model: str = "whisper-1") -> str: + model: str = "whisper-1", dryrun=False) -> str: """ get user input """ @@ -83,6 +83,7 @@ def record(agent: client.HostAgent, prompt: str, RecordRequest.RecordRequestAddPrompt(builder, prompt_off) if model: RecordRequest.RecordRequestAddModel(builder, model_off) + RecordRequest.RecordRequestAddDryrun(builder, dryrun) data_off = RecordRequest.RecordRequestEnd(builder) builder.Finish(data_off) res = agent.exec_request( diff --git a/sdk/python/spear/utils/tool.py b/sdk/python/spear/utils/tool.py index 82bb933..6168f75 100644 --- a/sdk/python/spear/utils/tool.py +++ b/sdk/python/spear/utils/tool.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 -import logging import inspect +import logging import flatbuffers as fbs import spear.client as client -from spear.proto.tool import ( - InternalToolCreateRequest, InternalToolCreateResponse, InternalToolCreateParamSpec) +from spear.proto.tool import (InternalToolCreateParamSpec, + InternalToolCreateRequest, + InternalToolCreateResponse) from spear.proto.transport import Method logger = logging.getLogger(__name__) diff --git a/spearlet/hostcalls/chat.go b/spearlet/hostcalls/chat.go index 9db39e5..647af65 100644 --- a/spearlet/hostcalls/chat.go +++ b/spearlet/hostcalls/chat.go @@ -375,7 +375,7 @@ func innerChatCompletion(inv *hcommon.InvocationInfo, chatReq *chat.ChatCompleti }) toolCalls := choice.Message.ToolCalls - log.Infof("Tool calls: %d", len(toolCalls)) + log.Debugf("Tool calls amount: %d", len(toolCalls)) for _, toolCall := range toolCalls { log.Infof("Tool call: %s", toolCall.Function.Name) argsStr := toolCall.Function.Arguments diff --git a/spearlet/hostcalls/common/models.go b/spearlet/hostcalls/common/models.go index f1a743b..568f7ed 100644 --- a/spearlet/hostcalls/common/models.go +++ b/spearlet/hostcalls/common/models.go @@ -2,7 +2,6 @@ package common import ( "os" - "strings" log "github.com/sirupsen/logrus" ) @@ -10,11 +9,12 @@ import ( type OpenAIFunctionType int type APIEndpointInfo struct { - Name string - Model string - Base *string - APIKey string - Url string + Name string + Model string + Base *string + APIKey string // used if APIKeyInEnv is empty + APIKeyInEnv string // if not empty, the API key is in env var + Url string } const ( @@ -42,18 +42,18 @@ var ( APIEndpointMap = map[OpenAIFunctionType][]APIEndpointInfo{ OpenAIFunctionTypeChatWithTools: { { - Name: "deepseek-toolchat", - Model: "deepseek-chat", - Base: &DeepSeekBase, - APIKey: os.Getenv("DEEPSEEK_API_KEY"), - Url: "/chat/completions", + Name: "deepseek-toolchat", + Model: "deepseek-chat", + Base: &DeepSeekBase, + APIKeyInEnv: "DEEPSEEK_API_KEY", + Url: "/chat/completions", }, { - Name: "openai-toolchat", - Model: "gpt-4o", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/chat/completions", + Name: "openai-toolchat", + Model: "gpt-4o", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/chat/completions", }, { Name: "qwen-toolchat-72b", @@ -93,11 +93,11 @@ var ( }, OpenAIFunctionTypeChatOnly: { { - Name: "openai-chat", - Model: "gpt-4o", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/chat/completions"}, + Name: "openai-chat", + Model: "gpt-4o", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/chat/completions"}, { Name: "llama-chat", Model: "llama", @@ -108,11 +108,11 @@ var ( }, OpenAIFunctionTypeEmbeddings: { { - Name: "openai-embed", - Model: "text-embedding-ada-002", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/embeddings", + Name: "openai-embed", + Model: "text-embedding-ada-002", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/embeddings", }, { Name: "nomic-embed", @@ -124,20 +124,20 @@ var ( }, OpenAIFunctionTypeTextToSpeech: { { - Name: "openai-tts", - Model: "tts-1", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/audio/speech", + Name: "openai-tts", + Model: "tts-1", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/audio/speech", }, }, OpenAIFunctionTypeImageGeneration: { { - Name: "openai-genimage", - Model: "dall-e-3", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/images/generations", + Name: "openai-genimage", + Model: "dall-e-3", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/images/generations", }, }, OpenAIFunctionTypeSpeechToText: { @@ -149,11 +149,11 @@ var ( Url: "/audio/transcriptions", }, { - Name: "openai-whisper", - Model: "whisper-1", - Base: &OpenAIBase, - APIKey: os.Getenv("OPENAI_API_KEY"), - Url: "/audio/transcriptions", + Name: "openai-whisper", + Model: "whisper-1", + Base: &OpenAIBase, + APIKeyInEnv: "OPENAI_API_KEY", + Url: "/audio/transcriptions", }, }, } @@ -169,40 +169,48 @@ func GetAPIEndpointInfo(ft OpenAIFunctionType, modelOrName string) []APIEndpoint res = append(res, info) } } - tmpList := make([]APIEndpointInfo, 0) + + // remove if the api key is from env but not set + res2 := make([]APIEndpointInfo, 0) for _, e := range res { - tmp := &APIEndpointInfo{ - Name: e.Name, - Model: e.Model, - Base: e.Base, - APIKey: "********", - Url: e.Url, - } - if e.APIKey == "" { - tmp.APIKey = "" + if e.APIKeyInEnv != "" { + key := os.Getenv(e.APIKeyInEnv) + if key == "" { + // skip if the key is not set + continue + } + res2 = append(res2, e) + res2[len(res2)-1].APIKey = key + } else { + res2 = append(res2, e) } - tmpList = append(tmpList, *tmp) } - log.Infof("Found %d endpoint(s) for %s: %v", len(tmpList), modelOrName, tmpList) - return res -} -func init() { - if os.Getenv("OPENAI_API_KEY") == "" { - log.Warnf("OPENAI_API_KEY not set, disabling openai functions") - newAPIEndpointMap := map[OpenAIFunctionType][]APIEndpointInfo{} - for ft, infoList := range APIEndpointMap { - newInfoList := []APIEndpointInfo{} - for _, info := range infoList { - // copy data only if the name does not contain "openai" - if !strings.Contains(info.Name, "openai") { - newInfoList = append(newInfoList, info) - } + func() { + // print the endpoint info found + tmpList := make([]APIEndpointInfo, 0) + for _, e := range res2 { + tmp := &APIEndpointInfo{ + Name: e.Name, + Model: e.Model, + Base: e.Base, + APIKey: "********", + Url: e.Url, } - newAPIEndpointMap[ft] = newInfoList + if e.APIKey == "" { + tmp.APIKey = "" + } + tmpList = append(tmpList, *tmp) } - APIEndpointMap = newAPIEndpointMap - } else { - OpenAIBase = "https://api.openai.com/v1" + log.Infof("Found %d endpoint(s) for %s: %v", len(tmpList), modelOrName, tmpList) + }() + + return res2 +} + +func init() { + if os.Getenv("OPENAI_API_BASE") != "" { + // official "https://api.openai.com/v1" + OpenAIBase = os.Getenv("OPENAI_API_BASE") } } diff --git a/spearlet/hostcalls/io.go b/spearlet/hostcalls/io.go index 4b7305d..bb0e27f 100644 --- a/spearlet/hostcalls/io.go +++ b/spearlet/hostcalls/io.go @@ -228,6 +228,11 @@ func Record(inv *hostcalls.InvocationInfo, args []byte) ([]byte, error) { // Wait for the user to press enter go func() { + if req.Dryrun() { + time.Sleep(3 * time.Second) + close(stopChan) + return + } _, _ = bufio.NewReader(os.Stdin).ReadBytes('\n') close(stopChan) }() diff --git a/spearlet/hostcalls/openai/openai_hc.go b/spearlet/hostcalls/openai/openai_hc.go index 575250f..9fc1a37 100644 --- a/spearlet/hostcalls/openai/openai_hc.go +++ b/spearlet/hostcalls/openai/openai_hc.go @@ -87,14 +87,14 @@ func OpenAIChatCompletion(ep common.APIEndpointInfo, chatReq *OpenAIChatCompleti // create a https request to https:///chat/completions and use b as the request body u := *ep.Base + ep.Url - log.Infof("URL: %s", u) - log.Infof("Request: %s", string(jsonBytes)) + log.Debugf("URL: %s", u) + log.Debugf("Request: %s", string(jsonBytes)) res, err := net.SendRequest(u, bytes.NewBuffer(jsonBytes), net.ContentTypeJSON, ep.APIKey) if err != nil { return nil, fmt.Errorf("error sending request: %v", err) } - log.Infof("Response: %s", string(res)) + log.Debugf("Response: %s", string(res)) respData := OpenAIChatCompletionResponse{} err = json.Unmarshal(res, &respData) if err != nil { diff --git a/spearlet/spearlet.go b/spearlet/spearlet.go index 7e6722f..f0e6747 100644 --- a/spearlet/spearlet.go +++ b/spearlet/spearlet.go @@ -22,6 +22,8 @@ import ( hostcalls "github.com/lfedgeai/spear/spearlet/hostcalls/common" "github.com/lfedgeai/spear/spearlet/task" _ "github.com/lfedgeai/spear/spearlet/tools" + + "github.com/docker/docker/client" ) var ( @@ -305,18 +307,118 @@ func (w *Spearlet) metaDataToTaskCfg(meta TaskMetaData) *task.TaskConfig { } } -func (w *Spearlet) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, - method string, data string) (string, error) { +func (w *Spearlet) ExecuteTaskNoMeta(funcName string, funcType task.TaskType, + wait bool, method string, data string) (string, error) { + var fakeMeta TaskMetaData + switch funcType { + case task.TaskTypeDocker: + // search if the docker image exists + // if not, return error + cli, err := client.NewClientWithOpts(client.FromEnv) + if err != nil { + return "", fmt.Errorf("error: %v", err) + } + + _, _, err = cli.ImageInspectWithRaw(context.Background(), funcName) + if err != nil { + return "", fmt.Errorf("error: %v", err) + } + + fakeMeta = TaskMetaData{ + Id: -1, + Type: task.TaskTypeDocker, + ImageName: funcName, + Name: funcName, + } + case task.TaskTypeProcess: + + fakeMeta = TaskMetaData{ + Id: -1, + Type: task.TaskTypeProcess, + ExecName: funcName, + Name: funcName, + } + case task.TaskTypeDylib: + panic("not implemented") + case task.TaskTypeWasm: + panic("not implemented") + default: + panic("invalid task type") + } + + log.Infof("Using metadata: %+v", fakeMeta) + + cfg := w.metaDataToTaskCfg(fakeMeta) + if cfg == nil { + return "", fmt.Errorf("error: invalid task type: %d", funcType) + } + rt, err := task.GetTaskRuntime(funcType) if err != nil { return "", fmt.Errorf("error: %v", err) } + newTask, err := rt.CreateTask(cfg) + if err != nil { + return "", fmt.Errorf("error: %v", err) + } + err = w.commMgr.InstallToTask(newTask) + if err != nil { + return "", fmt.Errorf("error: %v", err) + } + + log.Debugf("Starting task: %s", newTask.Name()) + newTask.Start() + + res := "" + builder := flatbuffers.NewBuilder(512) + methodOff := builder.CreateString(method) + dataOff := builder.CreateString(data) + custom.CustomRequestStart(builder) + custom.CustomRequestAddMethodStr(builder, methodOff) + custom.CustomRequestAddParamsStr(builder, dataOff) + builder.Finish(custom.CustomRequestEnd(builder)) + + if r, err := w.commMgr.SendOutgoingRPCRequest(newTask, transport.MethodCustom, + builder.FinishedBytes()); err != nil { + return "", fmt.Errorf("error: %v", err) + } else { + if len(r.ResponseBytes()) == 0 { + return "", nil // no response + } + customResp := custom.GetRootAsCustomResponse(r.ResponseBytes(), 0) + // marshal the result + if resTmp, err := json.Marshal(customResp.DataBytes()); err != nil { + return "", fmt.Errorf("error: %v", err) + } else { + res = string(resTmp) + } + } + + // terminate the task by sending a signal + if err := w.commMgr.SendOutgoingRPCSignal(newTask, transport.SignalTerminate, + []byte{}); err != nil { + return "", fmt.Errorf("error: %v", err) + } + + if wait { + // wait for the task to finish + newTask.Wait() + } + + return res, nil +} + +func (w *Spearlet) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, + method string, data string) (string, error) { // get metadata from taskId meta, ok := tmpMetaData[int(taskId)] if !ok { return "", fmt.Errorf("error: invalid task id: %d", taskId) } + if funcType == task.TaskTypeUnknown { + funcType = meta.Type + } if meta.Type != funcType { return "", fmt.Errorf("error: invalid task type: %d, %+v", funcType, meta) @@ -328,6 +430,12 @@ func (w *Spearlet) ExecuteTask(taskId int64, funcType task.TaskType, wait bool, if cfg == nil { return "", fmt.Errorf("error: invalid task type: %d", funcType) } + + rt, err := task.GetTaskRuntime(funcType) + if err != nil { + return "", fmt.Errorf("error: %v", err) + } + newTask, err := rt.CreateTask(cfg) if err != nil { return "", fmt.Errorf("error: %v", err) diff --git a/spearlet/task/proc_rt.go b/spearlet/task/proc_rt.go index a6bb459..ced7a6c 100644 --- a/spearlet/task/proc_rt.go +++ b/spearlet/task/proc_rt.go @@ -45,6 +45,10 @@ func (p *ProcessTaskRuntime) Start() error { } func (p *ProcessTaskRuntime) Stop() error { + // iterate all tasks and kill them + for _, task := range p.tasks { + task.Stop() + } return nil } @@ -57,7 +61,7 @@ func (p *ProcessTaskRuntime) CreateTask(cfg *TaskConfig) (Task, error) { task := NewProcessTask(cfg) - log.Infof("Command: %s %v", cfg.Cmd, cfg.Args) + log.Debugf("Command: %s %v", cfg.Cmd, cfg.Args) // execute the task cmd := exec.Command(cfg.Cmd, cfg.Args...) diff --git a/workload/docker/python/pytest-functionality/src/start.py b/workload/docker/python/pytest-functionality/src/start.py index 2849f17..810532f 100755 --- a/workload/docker/python/pytest-functionality/src/start.py +++ b/workload/docker/python/pytest-functionality/src/start.py @@ -6,9 +6,9 @@ import spear.client as client import spear.transform.chat as chat import spear.utils.io as io +from spear.utils.tool import register_internal_tool from spear.proto.tool import BuiltinToolID -from spear.utils.tool import register_internal_tool logging.basicConfig( level=logging.DEBUG, # Set the desired logging level @@ -23,14 +23,19 @@ agent = client.HostAgent() +TEST_LLM_MODEL = "gpt-4o" #"deepseek-toolchat" + def handle(params): """ handle the request """ logger.info("Handling request: %s", params) + logger.info("testing tool") + test_tool(TEST_LLM_MODEL) + logger.info("testing chat") - test_chat("gpt-4o") + test_chat(TEST_LLM_MODEL) logger.info("testing speak") test_speak("tts-1") @@ -41,8 +46,6 @@ def handle(params): logger.info("testing input") test_input() - logger.info("testing tool") - test_tool() # test("text-embedding-ada-002") # test("bge-large-en-v1.5") @@ -81,7 +84,7 @@ def test_record(model): """ logger.info("Testing model: %s", model) - resp = io.record(agent, "recording test") + resp = io.record(agent, "recording test", dryrun=True) assert resp is not None @@ -97,16 +100,17 @@ def test_input(): def test_tool_cb(param1, param2): """ - spear tool callback test function - - @param param1: test parameter 1 - @param param2: test parameter 2 + spear tool function for getting the sum of two numbers + + @param param1: first number + @param param2: second number """ logger.info("Testing tool callback %s %s", param1, param2) - return "test" + # parse params as int + return str(int(param1) + int(param2)) -def test_tool(): +def test_tool(model): """ test the model """ @@ -114,6 +118,17 @@ def test_tool(): tid = register_internal_tool(agent, test_tool_cb) logger.info("Registered tool: %d", tid) + resp = chat.chat(agent, "hi", model=model) + logger.info(resp) + resp = chat.chat(agent, "what is sum of 123 and 456?", + model=model, builtin_tools=[ + BuiltinToolID.BuiltinToolID.Datetime, + ], + internal_tools=[ + tid, + ]) + logger.info(resp) + if __name__ == "__main__": agent.register_handler("handle", handle) diff --git a/workload/process/python/pytest-functionality.py b/workload/process/python/pytest-functionality.py index 4e8a957..810532f 100755 --- a/workload/process/python/pytest-functionality.py +++ b/workload/process/python/pytest-functionality.py @@ -6,9 +6,9 @@ import spear.client as client import spear.transform.chat as chat import spear.utils.io as io +from spear.utils.tool import register_internal_tool from spear.proto.tool import BuiltinToolID -from spear.utils.tool import register_internal_tool logging.basicConfig( level=logging.DEBUG, # Set the desired logging level @@ -23,6 +23,8 @@ agent = client.HostAgent() +TEST_LLM_MODEL = "gpt-4o" #"deepseek-toolchat" + def handle(params): """ handle the request @@ -30,10 +32,10 @@ def handle(params): logger.info("Handling request: %s", params) logger.info("testing tool") - test_tool("deepseek-toolchat") + test_tool(TEST_LLM_MODEL) logger.info("testing chat") - test_chat("deepseek-toolchat") # "gpt-4o") + test_chat(TEST_LLM_MODEL) logger.info("testing speak") test_speak("tts-1") @@ -82,7 +84,7 @@ def test_record(model): """ logger.info("Testing model: %s", model) - resp = io.record(agent, "recording test") + resp = io.record(agent, "recording test", dryrun=True) assert resp is not None @@ -104,7 +106,8 @@ def test_tool_cb(param1, param2): @param param2: second number """ logger.info("Testing tool callback %s %s", param1, param2) - return "test" + # parse params as int + return str(int(param1) + int(param2)) def test_tool(model):