diff --git a/.gitignore b/.gitignore
index f8a263aaa..793134d36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,4 +28,5 @@ package-lock.json
/chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/src/main/resources/lib/*
/chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/lib/*
/lib
-/out/*
\ No newline at end of file
+/out/*
+/chat2db-gateway/target
diff --git a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx
index 592755979..997b8bd3d 100644
--- a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx
+++ b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx
@@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => {
};
const renderSelectTable = () => {
- const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props;
+ const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props;
const options = (tables || []).map((t) => ({ value: t, label: t }));
return (
onSelectTableSyncModel(v.target.value)}
- // value={syncTableModel}
- value={SyncModelType.MANUAL}
+ value={syncTableModel}
style={{ marginBottom: '8px' }}
>
- {/* 自动 */}
+ 自动
手动
diff --git a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx
index fc7e36403..6a50f80e9 100644
--- a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx
+++ b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx
@@ -184,7 +184,7 @@ const SelectBoundInfo = memo((props: IProps) => {
boundInfo.databaseName,
boundInfo.schemaName,
);
- setSelectedTables(tableNameListTemp.slice(0, 1));
+ //setSelectedTables(tableNameListTemp.slice(0, 1));
}
}, [allTableList, isActive]);
diff --git a/chat2db-gateway/pom.xml b/chat2db-gateway/pom.xml
new file mode 100644
index 000000000..68d8cbc0d
--- /dev/null
+++ b/chat2db-gateway/pom.xml
@@ -0,0 +1,84 @@
+
+
+ 4.0.0
+
+ com.hejianjun
+ chat2db-gateway
+ 0.0.1-SNAPSHOT
+ jar
+
+ chat2db-gateway
+ Project for chat2db-gateway
+
+
+ org.springframework.boot
+ spring-boot-starter-parent
+ 2.6.7
+
+
+
+
+ 11
+ 8.12.2
+ 2.0.1
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-web
+
+
+
+
+ co.elastic.clients
+ elasticsearch-java
+ 8.12.2
+
+
+ jakarta.json
+ jakarta.json-api
+ ${jakarta-json.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+ 2.12.3
+
+
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+
+
+ org.projectlombok
+ lombok
+ true
+
+
+ org.springframework.boot
+ spring-boot-starter-validation
+
+
+
+
+
+
+
+ org.springframework.boot
+ spring-boot-maven-plugin
+
+
+
+
+
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/Application.java b/chat2db-gateway/src/main/java/com/hejianjun/Application.java
new file mode 100644
index 000000000..23f0f58f3
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/Application.java
@@ -0,0 +1,18 @@
+package com.hejianjun;
+
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.boot.SpringApplication;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+
+@Slf4j
+@SpringBootApplication
+public class Application {
+ /**
+ * 主程序入口
+ * @param args 命令行参数
+ */
+ public static void main(String[] args) {
+ SpringApplication.run(Application.class, args);
+ }
+
+}
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java
new file mode 100644
index 000000000..82432390c
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java
@@ -0,0 +1,14 @@
+package com.hejianjun.bean;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+
+import java.math.BigDecimal;
+import java.util.List;
+
+@Data
+@AllArgsConstructor
+public class SchemaDocument {
+ private String schema;
+ private List
vector;
+}
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java
new file mode 100644
index 000000000..e3398b435
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java
@@ -0,0 +1,39 @@
+package com.hejianjun.bean;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+import javax.validation.constraints.NotNull;
+import java.math.BigDecimal;
+import java.util.List;
+
+/**
+ * 表结构请求
+ */
+@Data
+@SuperBuilder
+@NoArgsConstructor
+@AllArgsConstructor
+public class TableSchemaRequest {
+
+ // 数据源ID
+ @NotNull
+ private Long dataSourceId;
+ // 数据库名称
+ @NotNull
+ private String databaseName;
+ // API密钥
+ private String apiKey;
+ // 数据源模式
+ private String dataSourceSchema;
+ // 模式向量
+ @NotNull
+ private List> schemaVector;
+ // 模式列表
+ @NotNull
+ private List schemaList;
+ // 插入前删除
+ private Boolean deleteBeforeInsert = false;
+}
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java
new file mode 100644
index 000000000..c95bc67f5
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java
@@ -0,0 +1,41 @@
+package com.hejianjun.config;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.json.jackson.JacksonJsonpMapper;
+import co.elastic.clients.transport.ElasticsearchTransport;
+import co.elastic.clients.transport.rest_client.RestClientTransport;
+import org.apache.http.Header;
+import org.apache.http.HttpHost;
+import org.apache.http.message.BasicHeader;
+import org.elasticsearch.client.RestClient;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+@Configuration
+public class ElasticsearchClientConfig {
+
+ String apiKey = "0E9NGIy7gb8a3TDVM8dC";
+
+ /**
+ * 创建ElasticsearchClient实例
+ *
+ * @return ElasticsearchClient实例
+ */
+ @Bean
+ public ElasticsearchClient elasticsearchClient() {
+ // 初始化低级客户端
+ RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200))
+ .setDefaultHeaders(new Header[]{
+ new BasicHeader("Authorization", "ApiKey " + apiKey)
+ })
+ .build();
+
+ // 使用低级客户端创建传输层
+ ElasticsearchTransport transport = new RestClientTransport(
+ restClient, new JacksonJsonpMapper());
+
+ // 创建ElasticsearchClient实例
+ return new ElasticsearchClient(transport);
+ }
+
+}
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java
new file mode 100644
index 000000000..5abb24133
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java
@@ -0,0 +1,56 @@
+package com.hejianjun.controller;
+
+import com.hejianjun.bean.TableSchemaRequest;
+import com.hejianjun.service.TableSchemaService;
+import lombok.AllArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.ResponseEntity;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.io.IOException;
+import java.util.List;
+
+@Slf4j
+@RestController
+@AllArgsConstructor
+@RequestMapping("/api/client/milvus")
+public class TableSchemaController {
+
+ private final TableSchemaService service;
+
+
+ /**
+ * 保存表结构
+ * @param request 表结构请求对象
+ * @return 保存成功的文档ID
+ */
+ @PostMapping("/schema/save")
+ public ResponseEntity> saveSchema(@RequestBody TableSchemaRequest request) {
+ try {
+ List documentId = service.saveSchemaBatch(request);
+ return ResponseEntity.ok(documentId);
+ } catch (IOException e) {
+ log.error("保存表结构时发生错误", e);
+ return ResponseEntity.internalServerError().build();
+ }
+ }
+
+ /**
+ * 通过向量搜索表结构
+ * @param request 表结构搜索请求
+ * @return 搜索结果列表
+ */
+ @PostMapping("/schema/search")
+ public ResponseEntity searchByVector(@RequestBody TableSchemaRequest request) {
+ try {
+ TableSchemaRequest tableSchemaRequest = service.searchByVector(request);
+ return ResponseEntity.ok(tableSchemaRequest);
+ } catch (IOException e) {
+ log.error("Error searching schema", e);
+ return ResponseEntity.internalServerError().build();
+ }
+ }
+}
diff --git a/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java
new file mode 100644
index 000000000..3668f3358
--- /dev/null
+++ b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java
@@ -0,0 +1,105 @@
+package com.hejianjun.service;
+
+import co.elastic.clients.elasticsearch.ElasticsearchClient;
+import co.elastic.clients.elasticsearch.core.BulkRequest;
+import co.elastic.clients.elasticsearch.core.BulkResponse;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
+import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
+import co.elastic.clients.elasticsearch.core.search.Hit;
+import com.hejianjun.bean.SchemaDocument;
+import com.hejianjun.bean.TableSchemaRequest;
+import lombok.AllArgsConstructor;
+import org.springframework.stereotype.Service;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * TableSchemaService类用于处理表结构相关的操作。
+ */
+@Service
+@AllArgsConstructor
+public class TableSchemaService {
+
+ private final ElasticsearchClient client;
+
+ /**
+ * 批量保存表结构。
+ *
+ * @param request 表结构请求对象
+ * @return 保存成功后的每个文档的ID列表
+ * @throws IOException IO异常
+ */
+ public List saveSchemaBatch(TableSchemaRequest request) throws IOException {
+ List documentIds = new ArrayList<>();
+
+ // 构建批量请求
+ BulkRequest.Builder bulkBuilder = new BulkRequest.Builder();
+
+ String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema();
+
+ for (int i = 0; i < request.getSchemaVector().size(); i++) {
+ // 假设schemaVector和schemaList的长度相同,并且一一对应
+ List vector = request.getSchemaVector().get(i);
+ String schema = request.getSchemaList().get(i);
+
+ // 创建文档内容,这里简化为Map,具体结构根据需求定义
+ SchemaDocument document = new SchemaDocument(schema,vector);
+
+ // 添加到批量请求
+ bulkBuilder.operations(op -> op
+ .index(idx -> idx
+ .index(indexName)
+ .document(document)
+ )
+ );
+ }
+
+ // 执行批量请求
+ BulkResponse bulkResponse = client.bulk(bulkBuilder.build());
+
+ // 收集文档ID
+ for (BulkResponseItem item : bulkResponse.items()) {
+ if (item.error()!=null) {
+ throw new IOException("Error indexing document: " + item.error().reason());
+ }
+ documentIds.add(item.id());
+ }
+
+ return documentIds;
+ }
+
+ /**
+ * 根据向量搜索表结构。
+ *
+ * @param request 表结构请求对象
+ * @return 搜索结果列表
+ * @throws IOException IO异常
+ */
+ public TableSchemaRequest searchByVector(TableSchemaRequest request) throws IOException {
+ String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema();
+ List vector = request.getSchemaVector().get(0);
+ // 假设schemaVector已转换为适合Elasticsearch的格式
+ // 执行k-NN搜索
+ SearchResponse response = client.search(s -> s
+ .index(indexName)
+ // 这里添加k-NN查询逻辑,具体实现根据实际需求
+ , SchemaDocument.class
+ );
+ List> schemaVector = new ArrayList<>();
+ List schemaList = new ArrayList<>();
+ List> hits = response.hits().hits();
+ for (Hit hit: hits) {
+ SchemaDocument document = hit.source();
+ if(document!=null) {
+ schemaVector.add(document.getVector());
+ schemaList.add(document.getSchema());
+ }
+ }
+ request.setSchemaVector(schemaVector);
+ request.setSchemaList(schemaList);
+ return request;
+ }
+}
diff --git a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java
index 40a291955..d08cc4a6a 100644
--- a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java
+++ b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java
@@ -33,8 +33,13 @@ public List databases(Connection connection) {
@Override
public String tableDDL(Connection connection, @NotEmpty String databaseName, String schemaName,
@NotEmpty String tableName) {
- String sql = "SHOW CREATE TABLE " + format(databaseName) + "."
- + format(tableName);
+ String sql;
+ if(StringUtils.isEmpty(databaseName)) {
+ sql = "SHOW CREATE TABLE " + format(tableName);
+ }else{
+ sql = "SHOW CREATE TABLE " + format(databaseName) + "."
+ + format(tableName);
+ }
return SQLExecutor.getInstance().execute(connection, sql, resultSet -> {
if (resultSet.next()) {
return resultSet.getString("Create Table");
diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java
index 1454772d8..46cf1e033 100644
--- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java
+++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java
@@ -339,6 +339,16 @@ public PageResult pageQuery(TablePageQueryParam param, TableSelector sele
t.setComment(tableCacheDO.getExtendInfo());
t.setSchemaName(tableCacheDO.getSchemaName());
t.setDatabaseName(tableCacheDO.getDatabaseName());
+ if(Boolean.TRUE.equals(selector.getColumnList())){
+ TableQueryParam tableQueryParam = new TableQueryParam();
+ tableQueryParam.setDataSourceId(param.getDataSourceId());
+ tableQueryParam.setDatabaseName(param.getDatabaseName());
+ tableQueryParam.setSchemaName(param.getSchemaName());
+ tableQueryParam.setTableName(tableCacheDO.getTableName());
+ tableQueryParam.setRefresh(false);
+ List columns = queryColumns(tableQueryParam);
+ t.setColumnList(columns);
+ }
tables.add(t);
}
}
@@ -419,44 +429,41 @@ public ListResult queryTables(TablePageQueryParam param) {
private long addDBCache(Long dataSourceId, String databaseName, String schemaName, long version) {
String key = getTableKey(dataSourceId, databaseName, schemaName);
-
Connection connection = Chat2DBContext.getConnection();
long n = 0;
- try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, null,
- new String[]{"TABLE", "SYSTEM TABLE"})) {
- List cacheDOS = new ArrayList<>();
- while (resultSet.next()) {
- TableCacheDO tableCacheDO = new TableCacheDO();
- tableCacheDO.setDatabaseName(databaseName);
- tableCacheDO.setSchemaName(schemaName);
- tableCacheDO.setTableName(resultSet.getString("TABLE_NAME"));
- tableCacheDO.setExtendInfo(resultSet.getString("REMARKS"));
- tableCacheDO.setDataSourceId(dataSourceId);
- tableCacheDO.setVersion(version);
- tableCacheDO.setKey(key);
- cacheDOS.add(tableCacheDO);
- if (cacheDOS.size() >= 500) {
- getTableCacheMapper().batchInsert(cacheDOS);
- cacheDOS = new ArrayList<>();
- }
- n++;
- }
- if (!CollectionUtils.isEmpty(cacheDOS)) {
+ MetaData metaSchema = Chat2DBContext.getMetaData();
+ List tables = metaSchema.tables(connection, databaseName, schemaName, null);
+ List cacheDOS = new ArrayList<>();
+ for(Table table : tables){
+ TableCacheDO tableCacheDO = new TableCacheDO();
+ tableCacheDO.setDatabaseName(databaseName);
+ tableCacheDO.setSchemaName(schemaName);
+ tableCacheDO.setTableName(table.getName());
+ tableCacheDO.setExtendInfo(table.getComment());
+ tableCacheDO.setDataSourceId(dataSourceId);
+ tableCacheDO.setVersion(version);
+ tableCacheDO.setKey(key);
+ metaSchema.columns(connection, databaseName, schemaName, table.getName());
+ cacheDOS.add(tableCacheDO);
+ if (cacheDOS.size() >= 500) {
getTableCacheMapper().batchInsert(cacheDOS);
+ cacheDOS = new ArrayList<>();
}
- LambdaQueryWrapper q = new LambdaQueryWrapper();
- q.eq(TableCacheDO::getDataSourceId, dataSourceId);
- q.lt(TableCacheDO::getVersion, version);
- if (StringUtils.isNotBlank(databaseName)) {
- q.eq(TableCacheDO::getDatabaseName, databaseName);
- }
- if (StringUtils.isNotBlank(schemaName)) {
- q.eq(TableCacheDO::getSchemaName, schemaName);
- }
- getTableCacheMapper().delete(q);
- } catch (SQLException e) {
- throw new RuntimeException(e);
+ n++;
+ }
+ if (!CollectionUtils.isEmpty(cacheDOS)) {
+ getTableCacheMapper().batchInsert(cacheDOS);
+ }
+ LambdaQueryWrapper q = new LambdaQueryWrapper();
+ q.eq(TableCacheDO::getDataSourceId, dataSourceId);
+ q.lt(TableCacheDO::getVersion, version);
+ if (StringUtils.isNotBlank(databaseName)) {
+ q.eq(TableCacheDO::getDatabaseName, databaseName);
+ }
+ if (StringUtils.isNotBlank(schemaName)) {
+ q.eq(TableCacheDO::getSchemaName, schemaName);
}
+ getTableCacheMapper().delete(q);
return n;
}
@@ -480,7 +487,7 @@ private Long getLock(Long dataSourceId, String databaseName, String schemaName,
}
} else {
long version = versionDO.getVersion() + 1;
- LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper();
+ LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>();
queryWrapper.eq(TableCacheVersionDO::getId, versionDO.getId());
queryWrapper.eq(TableCacheVersionDO::getVersion, versionDO.getVersion());
versionDO.setVersion(version);
diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml
index c367e2605..37efbd21c 100644
--- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml
+++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml
@@ -25,7 +25,7 @@
and tc.schema_name = #{schemaName}
- and LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%'))
+ and (LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%')) or tc.extend_info like concat('%',#{searchKey},'%'))
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java
new file mode 100644
index 000000000..01e4c0eef
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java
@@ -0,0 +1,33 @@
+package ai.chat2db.server.web.api.aspect;
+
+
+import org.aspectj.lang.ProceedingJoinPoint;
+import org.aspectj.lang.annotation.Around;
+import org.aspectj.lang.annotation.Aspect;
+import org.aspectj.lang.annotation.Pointcut;
+import org.springframework.stereotype.Component;
+
+@Aspect
+@Component
+public class GatewayClientServiceAspect {
+ /**
+ * 定义切点,匹配 GatewayClientService 类中的所有方法
+ */
+ @Pointcut("execution(* ai.chat2db.server.web.api.http.GatewayClientService.*(..)) && !execution(* ai.chat2db.server.web.api.http.GatewayClientService.checkInWhite(..))")
+ public void gatewayClientServiceMethods() {}
+
+
+
+ /**
+ * 环绕通知:在切点方法执行时触发
+ * @param joinPoint
+ * @return
+ * @throws Throwable
+ */
+ @Around("gatewayClientServiceMethods()")
+ public Object aroundGatewayClientServiceMethods(ProceedingJoinPoint joinPoint) throws Throwable {
+ // 这里你可以执行一些自定义的逻辑,如果需要的话
+ // 然后返回 null 或其他默认值
+ return null;
+ }
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
index c9e77806f..8a4ca74eb 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java
@@ -1,20 +1,17 @@
package ai.chat2db.server.web.api.controller.ai;
+
+
import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
import ai.chat2db.server.domain.api.model.Config;
-import ai.chat2db.server.domain.api.model.DataSource;
-import ai.chat2db.server.domain.api.param.ShowCreateTableParam;
-import ai.chat2db.server.domain.api.param.TableQueryParam;
import ai.chat2db.server.domain.api.service.ConfigService;
-import ai.chat2db.server.domain.api.service.DataSourceService;
-import ai.chat2db.server.domain.api.service.TableService;
-import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum;
-import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
-import ai.chat2db.server.tools.common.util.EasyEnumUtils;
+import ai.chat2db.server.tools.common.model.LoginUser;
+import ai.chat2db.server.tools.common.util.ContextUtils;
import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient;
@@ -26,11 +23,9 @@
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage;
import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
-import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
-import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
-import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
import ai.chat2db.server.web.api.controller.ai.fastchat.listener.FastChatAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole;
import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
@@ -41,42 +36,36 @@
import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.tongyi.client.TongyiChatAIClient;
import ai.chat2db.server.web.api.controller.ai.tongyi.listener.TongyiChatAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.utils.PromptService;
import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient;
import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener;
import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient;
import ai.chat2db.server.web.api.controller.ai.zhipu.listener.ZhipuChatAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions;
import ai.chat2db.server.web.api.http.GatewayClientService;
-import ai.chat2db.server.web.api.http.model.EsTableSchema;
-import ai.chat2db.server.web.api.http.model.TableSchema;
-import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest;
-import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
-import ai.chat2db.server.web.api.http.request.WhiteListRequest;
-import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse;
-import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
-import com.alibaba.fastjson2.JSON;
import com.google.common.collect.Lists;
+import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
+import com.unfbx.chatgpt.entity.chat.tool.Tools;
+import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
-import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
-import java.math.BigDecimal;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.stream.Collectors;
/**
* 描述:
@@ -90,14 +79,6 @@
@Slf4j
public class ChatController {
- @Autowired
- private TableService tableService;
-
- @Autowired
- private ChatConverter chatConverter;
-
- @Autowired
- private DataSourceService dataSourceService;
@Value("${chatgpt.context.length}")
private Integer contextLength;
@@ -108,6 +89,10 @@ public class ChatController {
@Resource
private GatewayClientService gatewayClientService;
+
+ @Resource
+ protected PromptService promptService;
+
/**
* chat的超时时间
*/
@@ -171,7 +156,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc
/**
* 自定义模型非流式输出接口DEMO
*
- * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
+ * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致
*
*
* @param queryRequest
@@ -262,7 +247,7 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE
*/
private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
- RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
+ RestAIClient.getInstance().restCompletions(promptService.buildPrompt(prompt), eventSourceListener);
return sseEmitter;
}
@@ -276,11 +261,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter
* @throws IOException
*/
private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
- throws IOException {
- String prompt = buildPrompt(queryRequest);
+ throws IOException {
+ String prompt = promptService.buildAutoPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
- prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
+ prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
}
@@ -290,9 +275,17 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
-
- OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
- OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener);
+ LoginUser loginUser = ContextUtils.getLoginUser();
+ OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, promptService, queryRequest,loginUser);
+ ChatCompletion chatCompletion = ChatCompletion.builder()
+ .messages(messages).stream(true).build();
+ if(queryRequest.getDatabaseName()!=null){
+ ToolsFunction function = PromptService.getToolsFunction();
+ chatCompletion.setModel("gpt-3.5-turbo-0125");
+ chatCompletion.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)));
+ chatCompletion.setToolChoice("auto");
+ }
+ OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener);
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
return sseEmitter;
}
@@ -308,7 +301,7 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
*/
private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
- String prompt = buildPrompt(queryRequest);
+ String prompt = promptService.buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("exceed max token length:{},input length:{}", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
@@ -338,11 +331,13 @@ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter s
* @throws IOException
*/
private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
+ String prompt = promptService.buildAutoPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
prompt.length() / TOKEN_CONVERT_CHAR_LENGTH);
throw new ParamBusinessException();
+ }else{
+ log.info("提示词 :{}",prompt);
}
List messages = (List)LocalCache.CACHE.get(uid);
if (CollectionUtils.isNotEmpty(messages)) {
@@ -356,9 +351,16 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
-
- AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter);
- AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener);
+ LoginUser loginUser = ContextUtils.getLoginUser();
+ AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser);
+ AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages);
+ chatCompletionsOptions.setStream(true);
+ if(queryRequest.getDatabaseName()!=null){
+ ToolsFunction function = PromptService.getToolsFunction();
+ chatCompletionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)));
+ chatCompletionsOptions.setToolChoice("auto");
+ }
+ AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
@@ -373,8 +375,8 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse
* @throws IOException
*/
private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
- List messages = getFastChatMessage(uid, prompt);
+ String prompt = promptService.buildPrompt(queryRequest);
+ List messages = promptService.getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
@@ -394,13 +396,27 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter
* @throws IOException
*/
private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
- List messages = getFastChatMessage(uid, prompt);
+ String prompt = promptService.buildAutoPrompt(queryRequest);
+ log.info("原始提示词{}",prompt);
+ List messages = promptService.getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
-
- ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter);
- ZhipuChatAIClient.getInstance().streamCompletions(messages, sourceListener);
+ LoginUser loginUser = ContextUtils.getLoginUser();
+ ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser);
+ String requestId = String.valueOf(System.currentTimeMillis());
+ // 建议直接查看demo包代码,这里更新可能不及时
+ ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder()
+ .requestId(requestId)
+ .stream(true)
+
+ .messages(messages)
+ .build();
+ if(queryRequest.getDatabaseName()!=null){
+ ToolsFunction function = PromptService.getToolsFunction();
+ completionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)));
+ completionsOptions.setToolChoice("auto");
+ }
+ ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, sourceListener);
LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT);
return sseEmitter;
}
@@ -415,8 +431,8 @@ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter
* @throws IOException
*/
private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
- List messages = getFastChatMessage(uid, prompt);
+ String prompt = promptService.buildPrompt(queryRequest);
+ List messages = promptService.getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
@@ -436,8 +452,8 @@ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitte
* @throws IOException
*/
private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
- List messages = getFastChatMessage(uid, prompt);
+ String prompt = promptService.buildPrompt(queryRequest);
+ List messages = promptService.getFastChatMessage(uid, prompt);
buildSseEmitter(sseEmitter, uid);
@@ -447,26 +463,7 @@ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter
return sseEmitter;
}
- /**
- * get fast chat message
- *
- * @param uid
- * @param prompt
- * @return
- */
- private List getFastChatMessage(String uid, String prompt) {
- List messages = (List)LocalCache.CACHE.get(uid);
- if (CollectionUtils.isNotEmpty(messages)) {
- if (messages.size() >= contextLength) {
- messages = messages.subList(1, contextLength);
- }
- } else {
- messages = Lists.newArrayList();
- }
- FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt);
- messages.add(currentMessage);
- return messages;
- }
+
/**
* chat with wenxin chat openai
@@ -478,8 +475,8 @@ private List getFastChatMessage(String uid, String prompt) {
* @throws IOException
*/
private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
- List messages = getFastChatMessage(uid, prompt);
+ String prompt = promptService.buildPrompt(queryRequest);
+ List messages = promptService.getFastChatMessage(uid, prompt);
if (messages.size() >= 2 && messages.size() % 2 == 0) {
messages.remove(messages.size() - 1);
}
@@ -503,7 +500,7 @@ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter ss
* @throws IOException
*/
private SseEmitter chatWithClaudeAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
- String prompt = buildPrompt(queryRequest);
+ String prompt = promptService.buildPrompt(queryRequest);
ClaudeChatMessage claudeChatMessage = new ClaudeChatMessage();
claudeChatMessage.setText(prompt);
ClaudeChatCompletionsOptions chatCompletionsOptions = new ClaudeChatCompletionsOptions();
@@ -546,270 +543,5 @@ private SseEmitter buildSseEmitter(SseEmitter sseEmitter, String uid) throws IOE
return sseEmitter;
}
- /**
- * 构建schema参数
- *
- * @param tableQueryParam
- * @param tableNames
- * @return
- */
- private String buildTableColumn(TableQueryParam tableQueryParam,
- List tableNames) {
- if (CollectionUtils.isEmpty(tableNames)) {
- return "";
- }
- List schemaContent = Lists.newArrayList();
- try {
- schemaContent = tableNames.stream().map(tableName -> {
- tableQueryParam.setTableName(tableName);
- return queryTableDdl(tableName, tableQueryParam);
- }).collect(Collectors.toList());
- } catch (Exception exception) {
- log.error("query table error, do nothing");
- }
-
- return JSON.toJSONString(schemaContent);
- }
-
- /**
- * query table schema
- *
- * @param tableName
- * @param request
- * @return
- */
- private String queryTableDdl(String tableName, TableQueryParam request) {
- ShowCreateTableParam param = new ShowCreateTableParam();
- param.setTableName(tableName);
- param.setDataSourceId(request.getDataSourceId());
- param.setDatabaseName(request.getDatabaseName());
- param.setSchemaName(request.getSchemaName());
- DataResult tableSchema = tableService.showCreateTable(param);
- return tableSchema.getData();
- }
-
- /**
- * 构建prompt
- *
- * @param queryRequest
- * @return
- */
- private String buildPrompt(ChatQueryRequest queryRequest) {
- if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) {
- return queryRequest.getMessage();
- }
-
- // 查询schema信息
- String dataSourceType = queryDatabaseType(queryRequest);
- String properties = "";
- if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
- TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
- properties = buildTableColumn(queryParam, queryRequest.getTableNames());
- } else {
- properties = mappingDatabaseSchema(queryRequest);
- }
- String prompt = queryRequest.getMessage();
- String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
- : queryRequest.getPromptType();
- PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
- String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
- String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
- "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
- + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
- properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
- pType.getDescription(), ext, prompt);
- switch (pType) {
- case SQL_2_SQL:
- schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
- "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
- "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
- default:
- break;
- }
- String cleanedInput = schemaProperty.replaceAll("[\r\t]", "");
- return cleanedInput;
- }
-
- /**
- * query chat2db apikey
- *
- * @return
- */
- public String getApiKey() {
- ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
- Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
- String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
- // only sync for chat2db ai
- if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
- return null;
- }
- Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
- if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
- return null;
- }
- return keyConfig.getContent();
- }
-
- /**
- * query database type
- *
- * @param queryRequest
- * @return
- */
- public String queryDatabaseType(ChatQueryRequest queryRequest) {
- // 查询schema信息
- DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
- String dataSourceType = dataResult.getData().getType();
- if (StringUtils.isBlank(dataSourceType)) {
- dataSourceType = "MYSQL";
- }
- return dataSourceType;
- }
-
- public String mappingDatabaseSchema(ChatQueryRequest queryRequest) {
- String properties = "";
- String apiKey = getApiKey();
- if (StringUtils.isNotBlank(apiKey)) {
- boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
- if (res) {
-// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest);
- properties = queryDatabaseSchema(queryRequest);
- }
- }
- return properties;
- }
-
- /**
- * query database schema
- *
- * @param queryRequest
- * @return
- * @throws IOException
- */
- public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
- // request embedding
- FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
- List> contentVector = new ArrayList<>();
- if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) {
- return "";
- }
- contentVector.add(response.getData().get(0).getEmbedding());
-
- // search embedding
- TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
- tableSchemaRequest.setSchemaVector(contentVector);
- tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId());
- tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName());
- tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName());
- ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
- Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
- if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
- return "";
- }
- tableSchemaRequest.setApiKey(keyConfig.getContent());
- try {
- DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
- List schemas = Lists.newArrayList();
- if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
- for(TableSchema data: result.getData().getTableSchemas()){
- schemas.add(data.getTableSchema());
- }
- }
- if (CollectionUtils.isEmpty(schemas)) {
- return "";
- }
- String res = JSON.toJSONString(schemas);
- log.info("search vector result:{}", res);
- return res;
- } catch (Exception exception) {
- log.error("query table error, do nothing");
- return "";
- }
- }
-
- /**
- * query database schema
- *
- * @param queryRequest
- * @return
- * @throws IOException
- */
- public String querySchemaByEs(ChatQueryRequest queryRequest) {
- // search embedding
- EsTableSchemaRequest tableSchemaRequest = new EsTableSchemaRequest();
- tableSchemaRequest.setSearchKey(queryRequest.getMessage());
- tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId());
- tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName());
- tableSchemaRequest.setSchemaName(queryRequest.getSchemaName());
- ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
- Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
- if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
- return "";
- }
- tableSchemaRequest.setApiKey(keyConfig.getContent());
- try {
- DataResult result = gatewayClientService.schemaEsSearch(tableSchemaRequest);
- List schemas = Lists.newArrayList();
- if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
- for(EsTableSchema data: result.getData().getTableSchemas()){
- schemas.add(data.getTableSchemaContent());
- }
- }
- if (CollectionUtils.isEmpty(schemas)) {
- return "";
- }
- String res = JSON.toJSONString(schemas);
- log.info("search es result:{}", res);
- return res;
- } catch (Exception exception) {
- log.error("query es table error, do nothing");
- return "";
- }
- }
-
- /**
- * distribute embedding with different AI
- *
- * @return
- */
- public FastChatEmbeddingResponse distributeAIEmbedding(String input) {
- ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
- Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
- String aiSqlSource = config.getContent();
- if (Objects.isNull(aiSqlSource)) {
- return null;
- }
- AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
- switch (Objects.requireNonNull(aiSqlSourceEnum)) {
- case CHAT2DBAI:
- return embeddingWithChat2dbAi(input);
- case FASTCHATAI:
- return embeddingWithFastChatAi(input);
- }
- return null;
- }
-
- /**
- * embedding with fast chat openai
- *
- * @param input
- * @return
- * @throws IOException
- */
- private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) {
- FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input);
- return response;
- }
-
- /**
- * embedding with open ai
- *
- * @param input
- * @return
- */
- private FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) {
- FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input);
- return embeddings;
- }
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java
index c8c694309..70df31c6f 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java
@@ -242,7 +242,7 @@ public void syncTableVector(TableBriefQueryRequest param) throws Exception {
return;
}
- String apiKey = getApiKey();
+ String apiKey = promptService.getApiKey();
if (StringUtils.isBlank(apiKey)) {
return;
}
@@ -281,7 +281,7 @@ private void saveTableEmbedding(String tableSchema, TableSchemaRequest tableSche
List> contentVector = new ArrayList<>();
for(String str : schemaList){
// request embedding
- FastChatEmbeddingResponse response = distributeAIEmbedding(str);
+ FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str);
if(response == null){
throw new ParamBusinessException();
}
@@ -310,7 +310,7 @@ public void syncTableEs(TableBriefQueryRequest param) throws Exception {
return;
}
- String apiKey = getApiKey();
+ String apiKey = promptService.getApiKey();
if (StringUtils.isBlank(apiKey)) {
return;
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java
index 6ff16ee09..6ef0731ac 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java
@@ -70,7 +70,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request)
contentWordCount.add(str.length());
// request embedding
- FastChatEmbeddingResponse response = distributeAIEmbedding(str);
+ FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str);
if(response == null){
continue;
}
@@ -97,7 +97,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request)
public SseEmitter search(ChatQueryRequest queryRequest, @RequestHeader Map headers)
throws Exception {
// request embedding
- FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
+ FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(queryRequest.getMessage());
List> contentVector = new ArrayList<>();
contentVector.add(response.getData().get(0).getEmbedding());
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java
index 0c6180667..94caf7d4f 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java
@@ -63,8 +63,8 @@ public SseEmitter prompt(ChatQueryRequest queryRequest, @RequestHeader Map chatMessages, EventSourceListener eventSourceListener) {
- if (CollectionUtils.isEmpty(chatMessages)) {
- log.error("param error:Azure Prompt cannot be empty");
- throw new ParamBusinessException("prompt");
- }
+ public void streamCompletions(AzureChatCompletionsOptions chatCompletionsOptions, EventSourceListener eventSourceListener) {
if (Objects.isNull(eventSourceListener)) {
log.error("param error:AzureEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
- log.info("Azure Open AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
try {
-
- AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(chatMessages);
chatCompletionsOptions.setStream(true);
chatCompletionsOptions.setModel(this.deployId);
-
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
@@ -172,7 +164,7 @@ public void streamCompletions(List chatMessages, EventSourceLi
if (!endpoint.endsWith("/")) {
endpoint = endpoint + "/";
}
- String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2023-05-15";
+ String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2024-02-15-preview";
Request request = new Request.Builder()
.url(url)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java
index 4488bd6b8..2b9ab4a99 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java
@@ -1,15 +1,32 @@
package ai.chat2db.server.web.api.controller.ai.azure.listener;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import ai.chat2db.server.tools.common.model.LoginUser;
+import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatChoice;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletions;
+import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage;
+import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole;
import ai.chat2db.server.web.api.controller.ai.azure.model.AzureCompletionsUsage;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole;
+import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
+import ai.chat2db.server.web.api.controller.ai.utils.PromptService;
+import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient;
+import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions;
+
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
+import com.unfbx.chatgpt.entity.chat.tool.Tools;
+import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
+
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
@@ -26,123 +43,26 @@
* @date 2023-02-22
*/
@Slf4j
-public class AzureOpenAIEventSourceListener extends EventSourceListener {
+public class AzureOpenAIEventSourceListener extends OpenAIEventSourceListener {
- private SseEmitter sseEmitter;
- private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
-
- public AzureOpenAIEventSourceListener(SseEmitter sseEmitter) {
- this.sseEmitter = sseEmitter;
+ public AzureOpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService,
+ ChatQueryRequest queryRequest, LoginUser loginUser) {
+ super(sseEmitter, promptService, queryRequest, loginUser);
}
- /**
- * {@inheritDoc}
- */
@Override
- public void onOpen(EventSource eventSource, Response response) {
- log.info("AzureOpenAI建立sse连接...");
+ public String getName(){
+ return "AzureOpenAI";
}
- /**
- * {@inheritDoc}
- */
- @SneakyThrows
- @Override
- public void onEvent(EventSource eventSource, String id, String type, String data) {
- log.info("AzureOpenAI返回数据:{}", data);
- if (data.equals("[DONE]")) {
- log.info("AzureOpenAI返回数据结束了");
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]")
- .reconnectTime(3000));
- sseEmitter.complete();
- return;
- }
-
- AzureChatCompletions chatCompletions = mapper.readValue(data, AzureChatCompletions.class);
- String text = "";
- log.info("Model ID={} is created at {}.", chatCompletions.getId(),
- chatCompletions.getCreated());
- for (AzureChatChoice choice : chatCompletions.getChoices()) {
- AzureChatMessage message = choice.getDelta();
- if (message != null) {
- log.info("Index: {}, Chat Role: {}", choice.getIndex(), message.getRole());
- if (message.getContent() != null) {
- text = message.getContent();
- }
- }
- }
-
- AzureCompletionsUsage usage = chatCompletions.getUsage();
- if (usage != null) {
- log.info(
- "Usage: number of prompt token is {}, number of completion token is {}, and number of total "
- + "tokens in request and response is {}.%n", usage.getPromptTokens(),
- usage.getCompletionTokens(), usage.getTotalTokens());
- }
-
- Message message = new Message();
- message.setContent(text);
- sseEmitter.send(SseEmitter.event()
- .id(null)
- .data(message)
- .reconnectTime(3000));
- }
-
- @Override
- public void onClosed(EventSource eventSource) {
- try {
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- sseEmitter.complete();
- log.info("AzureOpenAI close sse connection...");
- }
-
- @Override
- public void onFailure(EventSource eventSource, Throwable t, Response response) {
- try {
- if (Objects.isNull(response)) {
- String message = t.getMessage();
- Message sseMessage = new Message();
- sseMessage.setContent(message);
- sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(sseMessage));
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- sseEmitter.complete();
- return;
- }
- ResponseBody body = response.body();
- String bodyString = Objects.nonNull(t) ? t.getMessage() : "";
- if (Objects.nonNull(body)) {
- bodyString = body.string();
- if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) {
- bodyString = t.getMessage();
- }
- log.error("Azure OpenAI sse response:{}", bodyString);
- } else {
- log.error("Azure OpenAI sse response:{},error:{}", response, t);
- }
- eventSource.cancel();
- Message message = new Message();
- message.setContent("Azure OpenAI error:" + bodyString);
- sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(message));
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- sseEmitter.complete();
- } catch (Exception exception) {
- log.error("Azure OpenAI发送数据异常:", exception);
- }
+ @Override
+ public void functionCall(String prompt){
+ AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt);
+ List messages = new ArrayList<>();
+ messages.add(currentMessage);
+ AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages);
+ chatCompletionsOptions.setStream(true);
+ AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, this);
}
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java
index 1d6198e57..8d33166b3 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java
@@ -7,6 +7,8 @@
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
+import com.unfbx.chatgpt.entity.chat.tool.Tools;
+
import lombok.Data;
/**
@@ -391,4 +393,11 @@ public AzureChatCompletionsOptions setModel(String model) {
this.model = model;
return this;
}
+
+ // 新添加的参数
+ @JsonProperty(value = "tool_choice")
+ private String toolChoice; // 工具选择策略
+
+ @JsonProperty(value = "tools")
+ private List tools; // 工具列表
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java
index 0f0b6d84f..295d39cff 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java
@@ -1,21 +1,15 @@
package ai.chat2db.server.web.api.controller.ai.chat2db.client;
-import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
-import ai.chat2db.server.domain.api.model.Config;
-import ai.chat2db.server.domain.api.service.ConfigService;
-import ai.chat2db.server.tools.base.wrapper.result.DataResult;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding;
import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
-import ai.chat2db.server.web.api.util.ApplicationContextUtil;
import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.ChatCompletion;
import com.unfbx.chatgpt.entity.chat.Message;
-import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java
index 9e9745c75..f6e833a01 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java
@@ -38,6 +38,12 @@ public enum PromptType implements BaseEnum {
* text generation
*/
TEXT_GENERATION("文本生成"),
+
+
+ /**
+ * GET_TABLE_COLUMNS
+ */
+ GET_TABLE_COLUMNS("获取指定表的属性"),
;
final String description;
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java
index 9ebf711c2..1d3de3bc7 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java
@@ -4,6 +4,7 @@
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Objects;
+import java.util.concurrent.TimeUnit;
import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
@@ -93,7 +94,17 @@ public static void refresh() {
log.info("refresh openai apikey:{}", maskApiKey(apikey));
if (Objects.nonNull(host) && Objects.nonNull(port)) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port));
- OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build();
+ OkHttpClient okHttpClient = new OkHttpClient.Builder()
+ // 设置连接超时为10秒
+ .connectTimeout(10, TimeUnit.SECONDS)
+ // 设置读取超时为30秒
+ .readTimeout(30, TimeUnit.SECONDS)
+ // 设置写入超时为15秒
+ .writeTimeout(15, TimeUnit.SECONDS)
+ // 设置整个调用的超时为1分钟
+ .callTimeout(1, TimeUnit.MINUTES)
+ .proxy(proxy)
+ .build();
OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build();
} else {
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java
index ccadf6d68..54609e72f 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java
@@ -1,20 +1,39 @@
package ai.chat2db.server.web.api.controller.ai.openai.listener;
-import java.util.Objects;
-
+import ai.chat2db.server.domain.repository.Dbutils;
+import ai.chat2db.server.tools.common.model.Context;
+import ai.chat2db.server.tools.common.model.LoginUser;
+import ai.chat2db.server.tools.common.util.ContextUtils;
+import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
+import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse;
+import ai.chat2db.server.web.api.controller.ai.utils.PromptService;
+import com.alibaba.fastjson2.JSONArray;
+import com.alibaba.fastjson2.JSONObject;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.entity.chat.Message;
+import com.unfbx.chatgpt.entity.chat.tool.ToolCallFunction;
+import com.unfbx.chatgpt.entity.chat.tool.ToolCalls;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
+
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.List;
+import java.util.Objects;
+
/**
* 描述:OpenAIEventSourceListener
*
@@ -24,61 +43,203 @@
@Slf4j
public class OpenAIEventSourceListener extends EventSourceListener {
- private SseEmitter sseEmitter;
+ private final SseEmitter sseEmitter;
+
+ protected final PromptService promptService;;
+
+ private final ChatQueryRequest queryRequest;
- public OpenAIEventSourceListener(SseEmitter sseEmitter) {
+ public final LoginUser loginUser;
+
+ private List toolCalls = new ArrayList<>();
+
+
+ public OpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, ChatQueryRequest queryRequest, LoginUser loginUser) {
this.sseEmitter = sseEmitter;
+ this.promptService = promptService;
+ this.queryRequest = queryRequest;
+ this.loginUser = loginUser;
+ }
+
+ public static List mergeToolCallsLists(List list1, List list2) {
+ List mergedList = new ArrayList<>(list1);
+ if (list2.isEmpty()) {
+ return mergedList;
+ }
+ ToolCalls item2 = list2.get(0);
+ boolean isMerged = false;
+ // 反向遍历
+ for (int i = list1.size() - 1; i >= 0; i--) {
+ ToolCalls item1 = list1.get(i);
+ if (item2.getId() == null || Objects.equals(item1.getId(), item2.getId())) {
+ mergedList.set(i, mergeToolCalls(item1, item2));
+ isMerged = true;
+ break;
+ }
+ }
+ if (!isMerged) {
+ // 如果 list2 中的对象与 list1 中的任何对象都不匹配,则作为新对象添加
+ mergedList.add(item2);
+ }
+ return mergedList;
+ }
+
+ private static ToolCalls mergeToolCalls(ToolCalls tc1, ToolCalls tc2) {
+ if (tc1 == null) return tc2;
+ if (tc2 == null) return tc1;
+
+ // 相同的逻辑,只是当 id 为 null 时进行合并
+ String id = tc1.getId() != null ? tc1.getId() : tc2.getId();
+ String type = mergeStrings(tc1.getType(), tc2.getType());
+ ToolCallFunction function = mergeToolCallFunctions(tc1.getFunction(), tc2.getFunction());
+
+ return new ToolCalls(id, type, function);
+ }
+
+ private static ToolCallFunction mergeToolCallFunctions(ToolCallFunction f1, ToolCallFunction f2) {
+ if (f1 == null) return f2;
+ if (f2 == null) return f1;
+
+ String name = mergeStrings(f1.getName(), f2.getName());
+ String arguments = mergeStrings(f1.getArguments(), f2.getArguments());
+
+ return new ToolCallFunction(name, arguments);
+ }
+
+ private static String mergeStrings(String str1, String str2) {
+ if (str1 != null && str2 != null) {
+ // Concatenate both strings
+ return str1 + str2;
+ } else if (str1 != null) {
+ return str1;
+ } else {
+ return str2;
+ }
}
+
+ public String getName() {
+ return "OpenAI";
+ }
/**
* {@inheritDoc}
*/
@Override
public void onOpen(EventSource eventSource, Response response) {
- log.info("OpenAI建立sse连接...");
+ log.info("{}建立sse连接...",getName());
}
+
+ public void functionCall(String prompt){
+ List messages = new ArrayList<>();
+ Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build();
+ messages.add(currentMessage);
+ OpenAIClient.getInstance().streamChatCompletion(messages, this);
+ }
+
+
+ public void handleTableNames(Set tableNames, Object instance) {
+ if (instance instanceof JSONArray) {
+ ((JSONArray) instance).forEach(item -> handleTableNames(tableNames, item));
+ } else if (instance instanceof JSONObject) {
+ ((JSONObject) instance).forEach((key, value) -> handleTableNames(tableNames, value));
+ } else if (instance instanceof String) {
+ String tableName = (String) instance;
+ List queryTableNames = queryRequest.getTableNames();
+ if (queryTableNames != null) {
+ String mostSimilarTableName = queryTableNames.stream()
+ // 根据相似度排序
+ .min(Comparator.comparingInt(existingTableName -> StringUtils.getLevenshteinDistance(existingTableName, tableName)))
+ .orElse(tableName);
+ tableNames.add(mostSimilarTableName);
+ }else{
+ tableNames.add(tableName);
+ }
+
+ }
+ }
/**
* {@inheritDoc}
*/
@SneakyThrows
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
- log.info("OpenAI返回数据:{}", data);
+ String scheme = getName();
+ log.info("{}返回数据:{}",scheme,data);
if (data.equals("[DONE]")) {
- log.info("OpenAI返回数据结束了");
+ if (toolCalls.isEmpty()) {
+ log.info("{}返回数据结束了",scheme);
+ sseEmitter.send(SseEmitter.event()
+ .id("[DONE]")
+ .data("[DONE]")
+ .reconnectTime(3000));
+ sseEmitter.complete();
+ return;
+ }
+ Set tableNames = new HashSet<>();
+ for (ToolCalls toolCall : toolCalls) {
+ String callId = toolCall.getId();
+ ToolCallFunction function = toolCall.getFunction();
+ if (function != null && Objects.nonNull(function.getArguments())) {
+ String functionName = function.getName();
+ if ("get_table_columns".equals(functionName)) {
+ JSONObject arguments = JSONObject.parse(function.getArguments());
+ handleTableNames(tableNames,arguments.get("table_names"));
+ log.info("原始参数:{},处理后:{}",arguments,tableNames);
+ }
+ }
+ }
+ Message message = new Message();
+ message.setContent("选择表" + tableNames+"\n");
sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]")
- .reconnectTime(3000));
- sseEmitter.complete();
+ .data(message)
+ .reconnectTime(3000));
+ queryRequest.setTableNames(new ArrayList<>(tableNames));
+ ContextUtils.setContext(Context.builder()
+ .loginUser(loginUser)
+ .build());
+ Dbutils.setSession();
+ String prompt = promptService.buildPrompt(queryRequest);
+ Dbutils.removeSession();
+ prompt = prompt.replaceAll("#", "");
+ log.info("{} 新提示词 :{}",scheme,prompt);
+ functionCall(prompt);
+ toolCalls.clear();
return;
}
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
// 读取Json
ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class);
- String text = completionResponse.getChoices().get(0).getDelta() == null
- ? completionResponse.getChoices().get(0).getText()
- : completionResponse.getChoices().get(0).getDelta().getContent();
+ if(CollectionUtils.isEmpty(completionResponse.getChoices())){
+ return;
+ }
+ Message delta = completionResponse.getChoices().get(0).getDelta();
+ if (delta != null && delta.getToolCalls() != null) {
+ this.toolCalls = mergeToolCallsLists(this.toolCalls, delta.getToolCalls());
+ }
+ String text = delta == null
+ ? completionResponse.getChoices().get(0).getText()
+ : delta.getContent();
Message message = new Message();
if (text != null) {
message.setContent(text);
sseEmitter.send(SseEmitter.event()
- .id(completionResponse.getId())
- .data(message)
- .reconnectTime(3000));
+ .id(completionResponse.getId())
+ .data(message)
+ .reconnectTime(3000));
}
}
@Override
public void onClosed(EventSource eventSource) {
- sseEmitter.complete();
- log.info("OpenAI关闭sse连接...");
+// sseEmitter.complete();
+// log.info("OpenAI关闭sse连接...");
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
+ String scheme = getName();
try {
if (Objects.isNull(response)) {
String message = t.getMessage();
@@ -88,11 +249,11 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) {
Message sseMessage = new Message();
sseMessage.setContent(message);
sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(sseMessage));
+ .id("[ERROR]")
+ .data(sseMessage));
sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
+ .id("[DONE]")
+ .data("[DONE]"));
sseEmitter.complete();
return;
}
@@ -100,22 +261,22 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) {
String bodyString = null;
if (Objects.nonNull(body)) {
bodyString = body.string();
- log.error("OpenAI sse连接异常data:{}", bodyString, t);
+ log.error("{} sse连接异常data:{}",scheme, bodyString, t);
} else {
- log.error("OpenAI sse连接异常data:{}", response, t);
+ log.error("{} sse连接异常data:{}",scheme, response, t);
}
eventSource.cancel();
Message message = new Message();
message.setContent("出现异常,请在帮助中查看详细日志:" + bodyString);
sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(message));
+ .id("[ERROR]")
+ .data(message));
sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
+ .id("[DONE]")
+ .data("[DONE]"));
sseEmitter.complete();
} catch (Exception exception) {
- log.error("发送数据异常:", exception);
+ log.error("{}发送数据异常:", scheme,exception);
}
}
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java
new file mode 100644
index 000000000..503c5543a
--- /dev/null
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java
@@ -0,0 +1,481 @@
+package ai.chat2db.server.web.api.controller.ai.utils;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+import org.apache.commons.collections.MapUtils;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.poi.ss.formula.functions.T;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.stereotype.Service;
+
+import com.alibaba.fastjson2.JSON;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.unfbx.chatgpt.entity.chat.Parameters;
+import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
+
+import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum;
+import ai.chat2db.server.domain.api.model.Config;
+import ai.chat2db.server.domain.api.model.DataSource;
+import ai.chat2db.server.domain.api.param.ShowCreateTableParam;
+import ai.chat2db.server.domain.api.param.TablePageQueryParam;
+import ai.chat2db.server.domain.api.param.TableQueryParam;
+import ai.chat2db.server.domain.api.param.TableSelector;
+import ai.chat2db.server.domain.api.service.ConfigService;
+import ai.chat2db.server.domain.api.service.DataSourceService;
+import ai.chat2db.server.domain.api.service.TableService;
+import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum;
+import ai.chat2db.server.tools.base.wrapper.result.DataResult;
+import ai.chat2db.server.tools.base.wrapper.result.PageResult;
+import ai.chat2db.server.tools.common.util.EasyEnumUtils;
+import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect;
+import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient;
+import ai.chat2db.server.web.api.controller.ai.config.LocalCache;
+import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter;
+import ai.chat2db.server.web.api.controller.ai.enums.PromptType;
+import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient;
+import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole;
+import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
+import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient;
+import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter;
+import ai.chat2db.server.web.api.http.GatewayClientService;
+import ai.chat2db.server.web.api.http.model.TableSchema;
+import ai.chat2db.server.web.api.http.request.TableSchemaRequest;
+import ai.chat2db.server.web.api.http.request.WhiteListRequest;
+import ai.chat2db.server.web.api.http.response.TableSchemaResponse;
+import ai.chat2db.server.web.api.util.ApplicationContextUtil;
+import ai.chat2db.spi.MetaData;
+import ai.chat2db.spi.model.Table;
+import ai.chat2db.spi.model.TableColumn;
+import ai.chat2db.spi.sql.Chat2DBContext;
+import jakarta.annotation.Resource;
+import lombok.extern.slf4j.Slf4j;
+
+
+@Slf4j
+@ConnectionInfoAspect
+@Service
+public class PromptService {
+
+
+ @Value("${chatgpt.context.length}")
+ private Integer contextLength;
+
+
+ @Autowired
+ private TableService tableService;
+
+ @Autowired
+ private DataSourceService dataSourceService;
+
+
+ @Autowired
+ private ChatConverter chatConverter;
+
+
+ @Resource
+ private GatewayClientService gatewayClientService;
+
+
+ @Autowired
+ private RdbWebConverter rdbWebConverter;
+
+
+ /**
+ * 构建prompt
+ *
+ * @param queryRequest
+ * @return
+ */
+ public String buildPrompt(ChatQueryRequest queryRequest) {
+ if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) {
+ return queryRequest.getMessage();
+ }
+
+ // 查询schema信息
+ String dataSourceType = queryDatabaseType(queryRequest);
+ String properties = "";
+ if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
+ TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
+ properties = buildTableColumn(queryParam, queryRequest.getTableNames());
+ } else {
+ properties = mappingDatabaseSchema(queryRequest);
+ }
+ String prompt = queryRequest.getMessage();
+ String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
+ : queryRequest.getPromptType();
+ PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
+ String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
+ String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
+ "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# "
+ + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
+ properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
+ pType.getDescription(), ext, prompt);
+ switch (pType) {
+ case SQL_2_SQL:
+ schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
+ "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
+ "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
+ default:
+ break;
+ }
+ String cleanedInput = schemaProperty.replaceAll("[\r\t]", "");
+ return cleanedInput;
+ }
+
+ public String mappingDatabaseSchema(ChatQueryRequest queryRequest) {
+ String properties = "";
+ String apiKey = getApiKey();
+ if (StringUtils.isNotBlank(apiKey)) {
+ boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData();
+ if (res) {
+// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest);
+ properties = queryDatabaseSchema(queryRequest);
+ }
+ }
+ return properties;
+ }
+
+
+ /**
+ * query chat2db apikey
+ *
+ * @return
+ */
+ public String getApiKey() {
+ ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
+ Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
+ String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode();
+ // only sync for chat2db ai
+ if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) {
+ return null;
+ }
+ Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
+ if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
+ return null;
+ }
+ return keyConfig.getContent();
+ }
+
+ /**
+ * 构建schema参数
+ *
+ * @param tableQueryParam
+ * @param tableNames
+ * @return
+ */
+ public String buildTableColumn(TableQueryParam tableQueryParam,
+ List tableNames) {
+ if (CollectionUtils.isEmpty(tableNames)) {
+ return "";
+ }
+ try {
+ return tableNames.stream().map(tableName -> {
+ tableQueryParam.setTableName(tableName);
+ return queryTableDdl(tableName, tableQueryParam);
+ }).collect(Collectors.joining(";\n"));
+ } catch (Exception exception) {
+ log.error("query table error, do nothing");
+ }
+
+ return "";
+ }
+
+ /**
+ * query table schema
+ *
+ * @param tableName
+ * @param request
+ * @return
+ */
+ public String queryTableDdl(String tableName, TableQueryParam request) {
+ ShowCreateTableParam param = new ShowCreateTableParam();
+ param.setTableName(tableName);
+ param.setDataSourceId(request.getDataSourceId());
+ param.setDatabaseName(request.getDatabaseName());
+ param.setSchemaName(request.getSchemaName());
+ DataResult tableSchema = tableService.showCreateTable(param);
+ return tableSchema.getData();
+ }
+
+ /**
+ * query database schema
+ *
+ * @param queryRequest
+ * @return
+ * @throws IOException
+ */
+ public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
+ // request embedding
+ FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage());
+ List> contentVector = new ArrayList<>();
+ if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) {
+ return "";
+ }
+ contentVector.add(response.getData().get(0).getEmbedding());
+
+ // search embedding
+ TableSchemaRequest tableSchemaRequest = new TableSchemaRequest();
+ tableSchemaRequest.setSchemaVector(contentVector);
+ tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId());
+ tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName());
+ tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName());
+ ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
+ Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData();
+ if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) {
+ return "";
+ }
+ tableSchemaRequest.setApiKey(keyConfig.getContent());
+ try {
+ DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest);
+ List schemas = Lists.newArrayList();
+ if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) {
+ for(TableSchema data: result.getData().getTableSchemas()){
+ schemas.add(data.getTableSchema());
+ }
+ }
+ if (CollectionUtils.isEmpty(schemas)) {
+ return "";
+ }
+ String res = JSON.toJSONString(schemas);
+ log.info("search vector result:{}", res);
+ return res;
+ } catch (Exception exception) {
+ log.error("query table error, do nothing");
+ return "";
+ }
+ }
+
+ /**
+ * distribute embedding with different AI
+ *
+ * @return
+ */
+ public FastChatEmbeddingResponse distributeAIEmbedding(String input) {
+ ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
+ Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData();
+ String aiSqlSource = config.getContent();
+ if (Objects.isNull(aiSqlSource)) {
+ return null;
+ }
+ AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource);
+ switch (Objects.requireNonNull(aiSqlSourceEnum)) {
+ case CHAT2DBAI:
+ return embeddingWithChat2dbAi(input);
+ case FASTCHATAI:
+ return embeddingWithFastChatAi(input);
+ }
+ return null;
+ }
+
+ /**
+ * embedding with fast chat openai
+ *
+ * @param input
+ * @return
+ * @throws IOException
+ */
+ public FastChatEmbeddingResponse embeddingWithFastChatAi(String input) {
+ FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input);
+ return response;
+ }
+
+ /**
+ * embedding with open ai
+ *
+ * @param input
+ * @return
+ */
+ public FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) {
+ FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input);
+ return embeddings;
+ }
+
+ /**
+ * 构建prompt
+ *
+ * @param queryRequest
+ * @return
+ */
+ public String buildAutoPrompt(ChatQueryRequest queryRequest) {
+ if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) {
+ return queryRequest.getMessage();
+ }
+ // 查询schema信息
+ String dataSourceType = queryDatabaseType(queryRequest);
+ String properties = "";
+ if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) {
+ TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest);
+ properties = buildTableColumn(queryParam, queryRequest.getTableNames());
+ } else {
+ properties = queryDatabaseTables(queryRequest);
+ }
+ String prompt = queryRequest.getMessage();
+ String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode()
+ : queryRequest.getPromptType();
+ PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType);
+ if (StringUtils.isNotEmpty(properties)) {
+ pType = PromptType.GET_TABLE_COLUMNS;
+ }
+ String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : "";
+ String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format(
+ "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# "
+ + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType,
+ properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s",
+ pType.getDescription(), ext, prompt);
+ switch (pType) {
+ case SQL_2_SQL:
+ schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format(
+ "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format(
+ "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType);
+ default:
+ break;
+ }
+ String cleanedInput = schemaProperty.replaceAll("[\r\t]", "");
+ return cleanedInput;
+ }
+
+
+ /**
+ * query database type
+ *
+ * @param queryRequest
+ * @return
+ */
+ public String queryDatabaseType(ChatQueryRequest queryRequest) {
+ // 查询schema信息
+ DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId());
+ String dataSourceType = dataResult.getData().getType();
+ if (StringUtils.isBlank(dataSourceType)) {
+ dataSourceType = "MYSQL";
+ }
+ return dataSourceType;
+ }
+
+
+ /**
+ * 根据给定的表对象找出所有可能的外键列
+ * @return 外键列名列表
+ */
+ public static List findPossibleForeignKeys(List columns) {
+ List foreignKeys = new ArrayList<>();
+ for (TableColumn column : columns) {
+ String columnName = column.getName();
+ // 假设TableColumn类有一个getTableName方法可以获取列所属的表名
+ String tableName = column.getTableName();
+ Boolean primaryKey = column.getPrimaryKey();
+
+ // 检查列名是否符合`关联表_id`的格式,并且列名前半部分不等于表名
+ if (columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey)) {
+ // 从列名中移除"_id"以获取可能的关联表名
+ String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3);
+
+ if (!potentialForeignKeyTable.equals(tableName)) {
+ foreignKeys.add(columnName);
+ }
+ }
+ }
+ return foreignKeys;
+ }
+
+ /**
+ * query database schema
+ *
+ * @param queryRequest
+ * @return
+ * @throws IOException
+ */
+ public String queryDatabaseTables(ChatQueryRequest queryRequest) {
+ try {
+ TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(queryRequest);
+ queryParam.queryAll();
+ TableSelector tableSelector = new TableSelector();
+ tableSelector.setColumnList(true);
+ tableSelector.setIndexList(false);
+ PageResult tables = tableService.pageQuery(queryParam,tableSelector);
+ List tableNames = new ArrayList<>();
+ String properties = tables.getData().stream().map(table -> {
+ tableNames.add(table.getName());
+ StringBuilder sb = new StringBuilder(table.getName()); // 直接在初始化时加入表名
+ String comment = table.getComment();
+ List columns = table.getColumnList();
+ List foreignKeys = findPossibleForeignKeys(columns);
+
+ // 只有当有注释或外键时才添加额外信息
+ if(StringUtils.isNotEmpty(comment) || !foreignKeys.isEmpty()){
+ sb.append("(").append(comment);
+
+ // 如果存在外键,添加外键信息
+ if(!foreignKeys.isEmpty()){
+ // 如果注释和外键都存在,先添加一个分隔符
+ if(StringUtils.isNotEmpty(comment)) {
+ sb.append("; ");
+ }
+ sb.append("外键:").append(String.join(", ", foreignKeys)); // 优化外键的展示
+ }
+ sb.append(")");
+ }
+ return sb.toString(); // 在映射阶段直接转换为字符串
+ })
+ .collect(Collectors.joining(","));
+ queryRequest.setTableNames(tableNames);
+ return properties;
+ } catch (Exception e) {
+ log.error("query table error:{}, do nothing", e.getMessage());
+ return "";
+ }
+ }
+
+ public static ToolsFunction getToolsFunction(){
+ return ToolsFunction.builder()
+ .name("get_table_columns")
+ .description(PromptType.GET_TABLE_COLUMNS.getDescription())
+ .parameters(Parameters.builder()
+ .type("object")
+ .properties(ImmutableMap.builder()
+ .put("table_names", ImmutableMap.builder()
+ .put("description", "表名,例如```User```")
+ .put("type", "array")
+ .put("items", ImmutableMap.of("type", "string"))
+ .put("uniqueItems", true)
+ .build())
+ .build())
+ .required(List.of("table_name"))
+ .build())
+ .build();
+ }
+
+
+ /**
+ * get fast chat message
+ *
+ * @param uid
+ * @param prompt
+ * @return
+ */
+ public List getFastChatMessage(String uid, String prompt) {
+ List messages = (List)LocalCache.CACHE.get(uid);
+ if (CollectionUtils.isNotEmpty(messages)) {
+ if (messages.size() >= contextLength) {
+ messages = messages.subList(1, contextLength);
+ }
+ } else {
+ messages = Lists.newArrayList();
+ }
+ FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt);
+ messages.add(currentMessage);
+ return messages;
+ }
+}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java
index f205f17f5..db0d35fa6 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java
@@ -58,8 +58,8 @@ private static ZhipuChatAIStreamClient singleton() {
public static void refresh() {
String apiKey = "";
- String apiHost = "https://open.bigmodel.cn/api/paas/v3/model-api/";
- String model = "chatglm_turbo";
+ String apiHost = "https://open.bigmodel.cn/api/paas/v4/chat/completions";
+ String model = "glm-4";
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(ZHIPU_HOST).getData();
if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) {
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java
index 550c929eb..ef0ec8071 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java
@@ -1,7 +1,6 @@
package ai.chat2db.server.web.api.controller.ai.zhipu.client;
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
-import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import ai.chat2db.server.web.api.controller.ai.zhipu.interceptor.ZhipuChatHeaderAuthorizationInterceptor;
import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions;
import cn.hutool.http.ContentType;
@@ -16,10 +15,8 @@
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
-import org.apache.commons.collections4.CollectionUtils;
import org.jetbrains.annotations.NotNull;
-import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
@@ -69,7 +66,6 @@ public class ZhipuChatAIStreamClient {
@Getter
private OkHttpClient okHttpClient;
-
/**
* @param builder
*/
@@ -90,13 +86,12 @@ private ZhipuChatAIStreamClient(Builder builder) {
* okhttpclient
*/
private OkHttpClient okHttpClient() {
- OkHttpClient okHttpClient = new OkHttpClient
- .Builder()
- .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret))
- .connectTimeout(10, TimeUnit.SECONDS)
- .writeTimeout(50, TimeUnit.SECONDS)
- .readTimeout(50, TimeUnit.SECONDS)
- .build();
+ OkHttpClient okHttpClient = new OkHttpClient.Builder()
+ .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret))
+ .connectTimeout(10, TimeUnit.SECONDS)
+ .writeTimeout(50, TimeUnit.SECONDS)
+ .readTimeout(50, TimeUnit.SECONDS)
+ .build();
return okHttpClient;
}
@@ -184,34 +179,26 @@ public ZhipuChatAIStreamClient build() {
* @param chatMessages
* @param eventSourceListener
*/
- public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) {
- if (CollectionUtils.isEmpty(chatMessages)) {
- log.error("param error:Zhipu Chat Prompt cannot be empty");
- throw new ParamBusinessException("prompt");
- }
+ public void streamCompletions(ZhipuChatCompletionsOptions completionsOptions, EventSourceListener eventSourceListener) {
+
if (Objects.isNull(eventSourceListener)) {
log.error("param error:Zhipu ChatEventSourceListener cannot be empty");
throw new ParamBusinessException();
}
- log.info("Zhipu Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
+ completionsOptions.setModel(this.model);
try {
- // 建议直接查看demo包代码,这里更新可能不及时
- ZhipuChatCompletionsOptions completionsOptions = new ZhipuChatCompletionsOptions();
- completionsOptions.setPrompt(chatMessages);
- completionsOptions.setModel(this.model);
- String requestId = String.valueOf(System.currentTimeMillis());
- completionsOptions.setRequestId(requestId);
+
ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
String requestBody = mapper.writeValueAsString(completionsOptions);
- String url = this.apiHost + "/" + this.model + "/" + "sse-invoke";
+ String url = this.apiHost;
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
Request request = new Request.Builder()
- .url(url)
- .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
- .build();
- //创建事件
+ .url(url)
+ .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
+ .build();
+ // 创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
log.info("finish invoking zhipu chat ai");
} catch (Exception e) {
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java
index a8b1ae016..a02668775 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java
@@ -1,22 +1,24 @@
package ai.chat2db.server.web.api.controller.ai.zhipu.listener;
+import ai.chat2db.server.tools.common.model.LoginUser;
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
-import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions;
-import com.fasterxml.jackson.databind.DeserializationFeature;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.unfbx.chatgpt.entity.chat.Message;
-import lombok.SneakyThrows;
+import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole;
+import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener;
+import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
+import ai.chat2db.server.web.api.controller.ai.utils.PromptService;
+import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient;
+import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.Response;
-import okhttp3.ResponseBody;
-import okhttp3.sse.EventSource;
-import okhttp3.sse.EventSourceListener;
-import org.apache.commons.lang3.StringUtils;
-import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
-import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Objects;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import com.unfbx.chatgpt.entity.chat.tool.Tools;
+import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction;
+
/**
* 描述:OpenAIEventSourceListener
*
@@ -24,111 +26,32 @@
* @date 2023-02-22
*/
@Slf4j
-public class ZhipuChatAIEventSourceListener extends EventSourceListener {
-
- private SseEmitter sseEmitter;
-
- private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
-
- public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter) {
- this.sseEmitter = sseEmitter;
- }
-
- /**
- * {@inheritDoc}
- */
- @Override
- public void onOpen(EventSource eventSource, Response response) {
- log.info("Zhipu Chat Sse connecting...");
- }
-
- /**
- * {@inheritDoc}
- */
- @SneakyThrows
- @Override
- public void onEvent(EventSource eventSource, String id, String type, String data) {
- log.info("Zhipu Chat AI response data:{}", data);
- if (data.equals("[DONE]")) {
- log.info("Zhipu Chat AI closed");
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]")
- .reconnectTime(3000));
- sseEmitter.complete();
- return;
- }
-
- ZhipuChatCompletions chatCompletions = mapper.readValue(data, ZhipuChatCompletions.class);
- String text = chatCompletions.getData();
- if (Objects.isNull(text)) {
- for (FastChatMessage message : chatCompletions.getBody().getChoices()) {
- if (message != null && message.getContent() != null) {
- text = message.getContent();
- }
- }
- }
-
- Message message = new Message();
- message.setContent(text);
- sseEmitter.send(SseEmitter.event()
- .id(null)
- .data(message)
- .reconnectTime(3000));
+public class ZhipuChatAIEventSourceListener extends OpenAIEventSourceListener {
+
+ public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService,
+ ChatQueryRequest queryRequest, LoginUser loginUser) {
+ super(sseEmitter, promptService, queryRequest, loginUser);
}
@Override
- public void onClosed(EventSource eventSource) {
- try {
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- sseEmitter.complete();
- log.info("ZhipuChatAI close sse connection...");
+ public String getName(){
+ return "Zhipu";
}
@Override
- public void onFailure(EventSource eventSource, Throwable t, Response response) {
- try {
- if (Objects.isNull(response)) {
- String message = t.getMessage();
- Message sseMessage = new Message();
- sseMessage.setContent(message);
- sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(sseMessage));
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- sseEmitter.complete();
- return;
- }
- ResponseBody body = response.body();
- String bodyString = Objects.nonNull(t) ? t.getMessage() : "";
- if (Objects.nonNull(body)) {
- bodyString = body.string();
- if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) {
- bodyString = t.getMessage();
- }
- log.error("Zhipu Chat AI sse response:{}", bodyString);
- } else {
- log.error("Zhipu Chat AI sse response:{},error:{}", response, t);
- }
- eventSource.cancel();
- Message message = new Message();
- message.setContent("Zhipu Chat AI error:" + bodyString);
- sseEmitter.send(SseEmitter.event()
- .id("[ERROR]")
- .data(message));
- sseEmitter.send(SseEmitter.event()
- .id("[DONE]")
- .data("[DONE]"));
- sseEmitter.complete();
- } catch (Exception exception) {
- log.error("Zhipu Chat AI send data error:", exception);
- }
+ public void functionCall(String prompt){
+ FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt);
+ List messages = new ArrayList<>();
+ messages.add(currentMessage);
+ String requestId = String.valueOf(System.currentTimeMillis());
+ ToolsFunction function = PromptService.getToolsFunction();
+ ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder()
+ .requestId(requestId)
+ .stream(true)
+ .toolChoice("auto")
+ .tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function)))
+ .messages(messages)
+ .build();
+ ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, this);
}
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java
index 4b6359cc2..06c16bd07 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java
@@ -5,6 +5,9 @@
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
import com.fasterxml.jackson.annotation.JsonProperty;
+import com.unfbx.chatgpt.entity.chat.tool.Tools;
+
+import lombok.Builder;
import lombok.Data;
import java.util.List;
@@ -14,18 +17,16 @@
* generate text that continues from or "completes" provided prompt data.
*/
@Data
+@Builder
public final class ZhipuChatCompletionsOptions {
@JsonProperty(value = "request_id")
private String requestId;
// sse-params
- @JsonProperty(value = "incremental")
+ @JsonProperty(value = "stream")
private Boolean stream = true;
- @JsonProperty(value = "sseFormat")
- private String sseFormat = "data";
-
/*
* The collection of context messages associated with this chat completions request.
@@ -33,8 +34,8 @@ public final class ZhipuChatCompletionsOptions {
* the behavior of the assistant, followed by alternating messages between the User and
* Assistant roles.
*/
- @JsonProperty(value = "prompt")
- private List prompt;
+ @JsonProperty(value = "messages")
+ private List messages;
//
@@ -45,4 +46,14 @@ public final class ZhipuChatCompletionsOptions {
*/
@JsonProperty(value = "model")
private String model;
+
+
+
+ // 新添加的参数
+ @JsonProperty(value = "tool_choice")
+ private String toolChoice; // 工具选择策略
+
+ @JsonProperty(value = "tools")
+ private List tools; // 工具列表
+
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java
index 131a6bf6c..f2ce64dfd 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java
@@ -18,17 +18,18 @@
import ai.chat2db.server.web.api.controller.rdb.vo.SqlVO;
import ai.chat2db.server.web.api.controller.rdb.vo.TableVO;
import ai.chat2db.spi.model.*;
-import ai.chat2db.spi.sql.Chat2DBContext;
-import ai.chat2db.spi.sql.ConnectInfo;
import com.google.common.collect.Lists;
import jakarta.validation.Valid;
import lombok.extern.slf4j.Slf4j;
+
+import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
@Slf4j
@ConnectionInfoAspect
@@ -250,4 +251,40 @@ public ActionResult delete(@Valid @RequestBody TableDeleteRequest request) {
DropParam dropParam = rdbWebConverter.tableDelete2dropParam(request);
return tableService.drop(dropParam);
}
+
+
+ /**
+ * 查询ER图
+ *
+ * @param request
+ * @return
+ */
+ @GetMapping("/er-diagram")
+ public DataResult erDiagram(@Valid TableBriefQueryRequest request) {
+ TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(request);
+ TableSelector tableSelector = new TableSelector();
+ tableSelector.setColumnList(true);
+ tableSelector.setIndexList(false);
+ PageResult tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector);
+ List entityList = tableDTOPageResult.getData().stream().map(table -> {
+ ErDiagram.Node entity = new ErDiagram.Node(table.getName(),
+ StringUtils.defaultIfBlank(table.getComment(), table.getName()));
+ return entity;
+ }).collect(Collectors.toList());
+ List relationList = tableDTOPageResult.getData().stream().flatMap(table -> {
+ return table.getColumnList().stream().filter(column -> {
+ String columnName = column.getName();
+ Boolean primaryKey = column.getPrimaryKey();
+ return columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey);
+ }).map(column -> {
+ String columnName = column.getName();
+ String tableName = column.getTableName();
+ // 从列名中移除"_id"以获取可能的关联表名
+ String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3);
+ ErDiagram.Edge relation = new ErDiagram.Edge(columnName,tableName, potentialForeignKeyTable,column.getComment());
+ return relation;
+ });
+ }).collect(Collectors.toList());
+ return DataResult.of(new ErDiagram(entityList, relationList));
+ }
}
diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java
index b04663dc9..5a37c352a 100644
--- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java
+++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java
@@ -3,6 +3,7 @@
import java.util.List;
import ai.chat2db.server.domain.api.param.*;
+import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.data.source.vo.DatabaseVO;
import ai.chat2db.server.web.api.controller.rdb.request.*;
import ai.chat2db.server.web.api.controller.rdb.vo.ColumnVO;
@@ -99,6 +100,14 @@ public abstract class RdbWebConverter {
* @return
*/
public abstract SqlVO dto2vo(Sql dto);
+
+ /**
+ * 参数转换
+ *
+ * @param request
+ * @return
+ */
+ public abstract TablePageQueryParam tablePageRequest2param(ChatQueryRequest request);
/**
* 参数转换
*
diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java
new file mode 100644
index 000000000..67a2f9920
--- /dev/null
+++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java
@@ -0,0 +1,42 @@
+package ai.chat2db.spi.model;
+
+import java.util.List;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import lombok.experimental.SuperBuilder;
+
+/**
+ * er图
+ */
+@Data
+@SuperBuilder
+@NoArgsConstructor
+@AllArgsConstructor
+public class ErDiagram {
+
+ private List nodes;
+ private List edges;
+
+ @Data
+ @SuperBuilder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Node {
+ private String id;
+ private String label;
+ }
+
+ @Data
+ @SuperBuilder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Edge {
+ private String id;
+ private String source;
+ private String target;
+ private String label;
+ }
+
+}
diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java
index 9e6fce81a..88183d9ad 100644
--- a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java
+++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java
@@ -142,6 +142,7 @@ public static void removeContext() {
try {
if (connection != null && !connection.isClosed()) {
connection.close();
+ connectInfo.setConnection(null);
}
} catch (SQLException e) {
log.error("close connection error", e);
diff --git a/chat2db-server/pom.xml b/chat2db-server/pom.xml
index 16c693477..b5b0cf43a 100644
--- a/chat2db-server/pom.xml
+++ b/chat2db-server/pom.xml
@@ -222,7 +222,7 @@
com.unfbx
chatgpt-java
- 1.0.8
+ 1.1.5
org.slf4j