diff --git a/.gitignore b/.gitignore index f8a263aaa..793134d36 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ package-lock.json /chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/src/main/resources/lib/* /chat2db-server/ali-dbhub-server-domain/ali-dbhub-server-domain-support/lib/* /lib -/out/* \ No newline at end of file +/out/* +/chat2db-gateway/target diff --git a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx index 592755979..997b8bd3d 100644 --- a/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx +++ b/chat2db-client/src/components/ConsoleEditor/components/ChatInput/index.tsx @@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => { }; const renderSelectTable = () => { - const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props; + const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props; const options = (tables || []).map((t) => ({ value: t, label: t })); return (
onSelectTableSyncModel(v.target.value)} - // value={syncTableModel} - value={SyncModelType.MANUAL} + value={syncTableModel} style={{ marginBottom: '8px' }} > - {/* 自动 */} + 自动 手动 diff --git a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx index fc7e36403..6a50f80e9 100644 --- a/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx +++ b/chat2db-client/src/components/ConsoleEditor/components/SelectBoundInfo/index.tsx @@ -184,7 +184,7 @@ const SelectBoundInfo = memo((props: IProps) => { boundInfo.databaseName, boundInfo.schemaName, ); - setSelectedTables(tableNameListTemp.slice(0, 1)); + //setSelectedTables(tableNameListTemp.slice(0, 1)); } }, [allTableList, isActive]); diff --git a/chat2db-gateway/pom.xml b/chat2db-gateway/pom.xml new file mode 100644 index 000000000..68d8cbc0d --- /dev/null +++ b/chat2db-gateway/pom.xml @@ -0,0 +1,84 @@ + + + 4.0.0 + + com.hejianjun + chat2db-gateway + 0.0.1-SNAPSHOT + jar + + chat2db-gateway + Project for chat2db-gateway + + + org.springframework.boot + spring-boot-starter-parent + 2.6.7 + + + + + 11 + 8.12.2 + 2.0.1 + + + + + + org.springframework.boot + spring-boot-starter-web + + + + + co.elastic.clients + elasticsearch-java + 8.12.2 + + + jakarta.json + jakarta.json-api + ${jakarta-json.version} + + + com.fasterxml.jackson.core + jackson-databind + 2.12.3 + + + + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + org.projectlombok + lombok + true + + + org.springframework.boot + spring-boot-starter-validation + + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + + diff --git a/chat2db-gateway/src/main/java/com/hejianjun/Application.java b/chat2db-gateway/src/main/java/com/hejianjun/Application.java new file mode 100644 index 000000000..23f0f58f3 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/Application.java @@ -0,0 +1,18 @@ +package com.hejianjun; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@Slf4j +@SpringBootApplication +public class Application { + /** + * 主程序入口 + * @param args 命令行参数 + */ + public static void main(String[] args) { + SpringApplication.run(Application.class, args); + } + +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java new file mode 100644 index 000000000..82432390c --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/SchemaDocument.java @@ -0,0 +1,14 @@ +package com.hejianjun.bean; + +import lombok.AllArgsConstructor; +import lombok.Data; + +import java.math.BigDecimal; +import java.util.List; + +@Data +@AllArgsConstructor +public class SchemaDocument { + private String schema; + private List vector; +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java new file mode 100644 index 000000000..e3398b435 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/bean/TableSchemaRequest.java @@ -0,0 +1,39 @@ +package com.hejianjun.bean; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +import javax.validation.constraints.NotNull; +import java.math.BigDecimal; +import java.util.List; + +/** + * 表结构请求 + */ +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class TableSchemaRequest { + + // 数据源ID + @NotNull + private Long dataSourceId; + // 数据库名称 + @NotNull + private String databaseName; + // API密钥 + private String apiKey; + // 数据源模式 + private String dataSourceSchema; + // 模式向量 + @NotNull + private List> schemaVector; + // 模式列表 + @NotNull + private List schemaList; + // 插入前删除 + private Boolean deleteBeforeInsert = false; +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java new file mode 100644 index 000000000..c95bc67f5 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/config/ElasticsearchClientConfig.java @@ -0,0 +1,41 @@ +package com.hejianjun.config; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.json.jackson.JacksonJsonpMapper; +import co.elastic.clients.transport.ElasticsearchTransport; +import co.elastic.clients.transport.rest_client.RestClientTransport; +import org.apache.http.Header; +import org.apache.http.HttpHost; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.client.RestClient; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class ElasticsearchClientConfig { + + String apiKey = "0E9NGIy7gb8a3TDVM8dC"; + + /** + * 创建ElasticsearchClient实例 + * + * @return ElasticsearchClient实例 + */ + @Bean + public ElasticsearchClient elasticsearchClient() { + // 初始化低级客户端 + RestClient restClient = RestClient.builder(new HttpHost("localhost", 9200)) + .setDefaultHeaders(new Header[]{ + new BasicHeader("Authorization", "ApiKey " + apiKey) + }) + .build(); + + // 使用低级客户端创建传输层 + ElasticsearchTransport transport = new RestClientTransport( + restClient, new JacksonJsonpMapper()); + + // 创建ElasticsearchClient实例 + return new ElasticsearchClient(transport); + } + +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java new file mode 100644 index 000000000..5abb24133 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/controller/TableSchemaController.java @@ -0,0 +1,56 @@ +package com.hejianjun.controller; + +import com.hejianjun.bean.TableSchemaRequest; +import com.hejianjun.service.TableSchemaService; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RestController; + +import java.io.IOException; +import java.util.List; + +@Slf4j +@RestController +@AllArgsConstructor +@RequestMapping("/api/client/milvus") +public class TableSchemaController { + + private final TableSchemaService service; + + + /** + * 保存表结构 + * @param request 表结构请求对象 + * @return 保存成功的文档ID + */ + @PostMapping("/schema/save") + public ResponseEntity> saveSchema(@RequestBody TableSchemaRequest request) { + try { + List documentId = service.saveSchemaBatch(request); + return ResponseEntity.ok(documentId); + } catch (IOException e) { + log.error("保存表结构时发生错误", e); + return ResponseEntity.internalServerError().build(); + } + } + + /** + * 通过向量搜索表结构 + * @param request 表结构搜索请求 + * @return 搜索结果列表 + */ + @PostMapping("/schema/search") + public ResponseEntity searchByVector(@RequestBody TableSchemaRequest request) { + try { + TableSchemaRequest tableSchemaRequest = service.searchByVector(request); + return ResponseEntity.ok(tableSchemaRequest); + } catch (IOException e) { + log.error("Error searching schema", e); + return ResponseEntity.internalServerError().build(); + } + } +} diff --git a/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java new file mode 100644 index 000000000..3668f3358 --- /dev/null +++ b/chat2db-gateway/src/main/java/com/hejianjun/service/TableSchemaService.java @@ -0,0 +1,105 @@ +package com.hejianjun.service; + +import co.elastic.clients.elasticsearch.ElasticsearchClient; +import co.elastic.clients.elasticsearch.core.BulkRequest; +import co.elastic.clients.elasticsearch.core.BulkResponse; +import co.elastic.clients.elasticsearch.core.SearchResponse; +import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem; +import co.elastic.clients.elasticsearch.core.search.Hit; +import com.hejianjun.bean.SchemaDocument; +import com.hejianjun.bean.TableSchemaRequest; +import lombok.AllArgsConstructor; +import org.springframework.stereotype.Service; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.List; + +/** + * TableSchemaService类用于处理表结构相关的操作。 + */ +@Service +@AllArgsConstructor +public class TableSchemaService { + + private final ElasticsearchClient client; + + /** + * 批量保存表结构。 + * + * @param request 表结构请求对象 + * @return 保存成功后的每个文档的ID列表 + * @throws IOException IO异常 + */ + public List saveSchemaBatch(TableSchemaRequest request) throws IOException { + List documentIds = new ArrayList<>(); + + // 构建批量请求 + BulkRequest.Builder bulkBuilder = new BulkRequest.Builder(); + + String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema(); + + for (int i = 0; i < request.getSchemaVector().size(); i++) { + // 假设schemaVector和schemaList的长度相同,并且一一对应 + List vector = request.getSchemaVector().get(i); + String schema = request.getSchemaList().get(i); + + // 创建文档内容,这里简化为Map,具体结构根据需求定义 + SchemaDocument document = new SchemaDocument(schema,vector); + + // 添加到批量请求 + bulkBuilder.operations(op -> op + .index(idx -> idx + .index(indexName) + .document(document) + ) + ); + } + + // 执行批量请求 + BulkResponse bulkResponse = client.bulk(bulkBuilder.build()); + + // 收集文档ID + for (BulkResponseItem item : bulkResponse.items()) { + if (item.error()!=null) { + throw new IOException("Error indexing document: " + item.error().reason()); + } + documentIds.add(item.id()); + } + + return documentIds; + } + + /** + * 根据向量搜索表结构。 + * + * @param request 表结构请求对象 + * @return 搜索结果列表 + * @throws IOException IO异常 + */ + public TableSchemaRequest searchByVector(TableSchemaRequest request) throws IOException { + String indexName = request.getDataSourceId() + request.getDatabaseName() + request.getDataSourceSchema(); + List vector = request.getSchemaVector().get(0); + // 假设schemaVector已转换为适合Elasticsearch的格式 + // 执行k-NN搜索 + SearchResponse response = client.search(s -> s + .index(indexName) + // 这里添加k-NN查询逻辑,具体实现根据实际需求 + , SchemaDocument.class + ); + List> schemaVector = new ArrayList<>(); + List schemaList = new ArrayList<>(); + List> hits = response.hits().hits(); + for (Hit hit: hits) { + SchemaDocument document = hit.source(); + if(document!=null) { + schemaVector.add(document.getVector()); + schemaList.add(document.getSchema()); + } + } + request.setSchemaVector(schemaVector); + request.setSchemaList(schemaList); + return request; + } +} diff --git a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java index 40a291955..d08cc4a6a 100644 --- a/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java +++ b/chat2db-server/chat2db-plugins/chat2db-mysql/src/main/java/ai/chat2db/plugin/mysql/MysqlMetaData.java @@ -33,8 +33,13 @@ public List databases(Connection connection) { @Override public String tableDDL(Connection connection, @NotEmpty String databaseName, String schemaName, @NotEmpty String tableName) { - String sql = "SHOW CREATE TABLE " + format(databaseName) + "." - + format(tableName); + String sql; + if(StringUtils.isEmpty(databaseName)) { + sql = "SHOW CREATE TABLE " + format(tableName); + }else{ + sql = "SHOW CREATE TABLE " + format(databaseName) + "." + + format(tableName); + } return SQLExecutor.getInstance().execute(connection, sql, resultSet -> { if (resultSet.next()) { return resultSet.getString("Create Table"); diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java index 1454772d8..46cf1e033 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-core/src/main/java/ai/chat2db/server/domain/core/impl/TableServiceImpl.java @@ -339,6 +339,16 @@ public PageResult pageQuery(TablePageQueryParam param, TableSelector sele t.setComment(tableCacheDO.getExtendInfo()); t.setSchemaName(tableCacheDO.getSchemaName()); t.setDatabaseName(tableCacheDO.getDatabaseName()); + if(Boolean.TRUE.equals(selector.getColumnList())){ + TableQueryParam tableQueryParam = new TableQueryParam(); + tableQueryParam.setDataSourceId(param.getDataSourceId()); + tableQueryParam.setDatabaseName(param.getDatabaseName()); + tableQueryParam.setSchemaName(param.getSchemaName()); + tableQueryParam.setTableName(tableCacheDO.getTableName()); + tableQueryParam.setRefresh(false); + List columns = queryColumns(tableQueryParam); + t.setColumnList(columns); + } tables.add(t); } } @@ -419,44 +429,41 @@ public ListResult queryTables(TablePageQueryParam param) { private long addDBCache(Long dataSourceId, String databaseName, String schemaName, long version) { String key = getTableKey(dataSourceId, databaseName, schemaName); - Connection connection = Chat2DBContext.getConnection(); long n = 0; - try (ResultSet resultSet = connection.getMetaData().getTables(databaseName, schemaName, null, - new String[]{"TABLE", "SYSTEM TABLE"})) { - List cacheDOS = new ArrayList<>(); - while (resultSet.next()) { - TableCacheDO tableCacheDO = new TableCacheDO(); - tableCacheDO.setDatabaseName(databaseName); - tableCacheDO.setSchemaName(schemaName); - tableCacheDO.setTableName(resultSet.getString("TABLE_NAME")); - tableCacheDO.setExtendInfo(resultSet.getString("REMARKS")); - tableCacheDO.setDataSourceId(dataSourceId); - tableCacheDO.setVersion(version); - tableCacheDO.setKey(key); - cacheDOS.add(tableCacheDO); - if (cacheDOS.size() >= 500) { - getTableCacheMapper().batchInsert(cacheDOS); - cacheDOS = new ArrayList<>(); - } - n++; - } - if (!CollectionUtils.isEmpty(cacheDOS)) { + MetaData metaSchema = Chat2DBContext.getMetaData(); + List
tables = metaSchema.tables(connection, databaseName, schemaName, null); + List cacheDOS = new ArrayList<>(); + for(Table table : tables){ + TableCacheDO tableCacheDO = new TableCacheDO(); + tableCacheDO.setDatabaseName(databaseName); + tableCacheDO.setSchemaName(schemaName); + tableCacheDO.setTableName(table.getName()); + tableCacheDO.setExtendInfo(table.getComment()); + tableCacheDO.setDataSourceId(dataSourceId); + tableCacheDO.setVersion(version); + tableCacheDO.setKey(key); + metaSchema.columns(connection, databaseName, schemaName, table.getName()); + cacheDOS.add(tableCacheDO); + if (cacheDOS.size() >= 500) { getTableCacheMapper().batchInsert(cacheDOS); + cacheDOS = new ArrayList<>(); } - LambdaQueryWrapper q = new LambdaQueryWrapper(); - q.eq(TableCacheDO::getDataSourceId, dataSourceId); - q.lt(TableCacheDO::getVersion, version); - if (StringUtils.isNotBlank(databaseName)) { - q.eq(TableCacheDO::getDatabaseName, databaseName); - } - if (StringUtils.isNotBlank(schemaName)) { - q.eq(TableCacheDO::getSchemaName, schemaName); - } - getTableCacheMapper().delete(q); - } catch (SQLException e) { - throw new RuntimeException(e); + n++; + } + if (!CollectionUtils.isEmpty(cacheDOS)) { + getTableCacheMapper().batchInsert(cacheDOS); + } + LambdaQueryWrapper q = new LambdaQueryWrapper(); + q.eq(TableCacheDO::getDataSourceId, dataSourceId); + q.lt(TableCacheDO::getVersion, version); + if (StringUtils.isNotBlank(databaseName)) { + q.eq(TableCacheDO::getDatabaseName, databaseName); + } + if (StringUtils.isNotBlank(schemaName)) { + q.eq(TableCacheDO::getSchemaName, schemaName); } + getTableCacheMapper().delete(q); return n; } @@ -480,7 +487,7 @@ private Long getLock(Long dataSourceId, String databaseName, String schemaName, } } else { long version = versionDO.getVersion() + 1; - LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper(); + LambdaQueryWrapper queryWrapper = new LambdaQueryWrapper<>(); queryWrapper.eq(TableCacheVersionDO::getId, versionDO.getId()); queryWrapper.eq(TableCacheVersionDO::getVersion, versionDO.getVersion()); versionDO.setVersion(version); diff --git a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml index c367e2605..37efbd21c 100644 --- a/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml +++ b/chat2db-server/chat2db-server-domain/chat2db-server-domain-repository/src/main/resources/mapper/TableCacheMapper.xml @@ -25,7 +25,7 @@ and tc.schema_name = #{schemaName} - and LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%')) + and (LOWER(tc.table_name) like LOWER(concat('%',#{searchKey},'%')) or tc.extend_info like concat('%',#{searchKey},'%')) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java new file mode 100644 index 000000000..01e4c0eef --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/aspect/GatewayClientServiceAspect.java @@ -0,0 +1,33 @@ +package ai.chat2db.server.web.api.aspect; + + +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.springframework.stereotype.Component; + +@Aspect +@Component +public class GatewayClientServiceAspect { + /** + * 定义切点,匹配 GatewayClientService 类中的所有方法 + */ + @Pointcut("execution(* ai.chat2db.server.web.api.http.GatewayClientService.*(..)) && !execution(* ai.chat2db.server.web.api.http.GatewayClientService.checkInWhite(..))") + public void gatewayClientServiceMethods() {} + + + + /** + * 环绕通知:在切点方法执行时触发 + * @param joinPoint + * @return + * @throws Throwable + */ + @Around("gatewayClientServiceMethods()") + public Object aroundGatewayClientServiceMethods(ProceedingJoinPoint joinPoint) throws Throwable { + // 这里你可以执行一些自定义的逻辑,如果需要的话 + // 然后返回 null 或其他默认值 + return null; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java index c9e77806f..8a4ca74eb 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java @@ -1,20 +1,17 @@ package ai.chat2db.server.web.api.controller.ai; + + import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; import ai.chat2db.server.domain.api.model.Config; -import ai.chat2db.server.domain.api.model.DataSource; -import ai.chat2db.server.domain.api.param.ShowCreateTableParam; -import ai.chat2db.server.domain.api.param.TableQueryParam; import ai.chat2db.server.domain.api.service.ConfigService; -import ai.chat2db.server.domain.api.service.DataSourceService; -import ai.chat2db.server.domain.api.service.TableService; -import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; -import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; -import ai.chat2db.server.tools.common.util.EasyEnumUtils; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.tools.common.util.ContextUtils; import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.listener.AzureOpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.baichuan.client.BaichuanAIClient; @@ -26,11 +23,9 @@ import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.claude.model.ClaudeChatMessage; import ai.chat2db.server.web.api.controller.ai.config.LocalCache; -import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; -import ai.chat2db.server.web.api.controller.ai.enums.PromptType; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; -import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; import ai.chat2db.server.web.api.controller.ai.fastchat.listener.FastChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient; @@ -41,42 +36,36 @@ import ai.chat2db.server.web.api.controller.ai.rest.listener.RestAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.tongyi.client.TongyiChatAIClient; import ai.chat2db.server.web.api.controller.ai.tongyi.listener.TongyiChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; import ai.chat2db.server.web.api.controller.ai.wenxin.client.WenxinAIClient; import ai.chat2db.server.web.api.controller.ai.wenxin.listener.WenxinAIEventSourceListener; import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; import ai.chat2db.server.web.api.controller.ai.zhipu.listener.ZhipuChatAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import ai.chat2db.server.web.api.http.GatewayClientService; -import ai.chat2db.server.web.api.http.model.EsTableSchema; -import ai.chat2db.server.web.api.http.model.TableSchema; -import ai.chat2db.server.web.api.http.request.EsTableSchemaRequest; -import ai.chat2db.server.web.api.http.request.TableSchemaRequest; -import ai.chat2db.server.web.api.http.request.WhiteListRequest; -import ai.chat2db.server.web.api.http.response.EsTableSchemaResponse; -import ai.chat2db.server.web.api.http.response.TableSchemaResponse; import ai.chat2db.server.web.api.util.ApplicationContextUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; -import com.alibaba.fastjson2.JSON; import com.google.common.collect.Lists; +import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; import jakarta.annotation.Resource; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.web.bind.annotation.*; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; import java.io.IOException; -import java.math.BigDecimal; import java.time.Duration; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; /** * 描述: @@ -90,14 +79,6 @@ @Slf4j public class ChatController { - @Autowired - private TableService tableService; - - @Autowired - private ChatConverter chatConverter; - - @Autowired - private DataSourceService dataSourceService; @Value("${chatgpt.context.length}") private Integer contextLength; @@ -108,6 +89,10 @@ public class ChatController { @Resource private GatewayClientService gatewayClientService; + + @Resource + protected PromptService promptService; + /** * chat的超时时间 */ @@ -171,7 +156,7 @@ public SseEmitter customChat(@RequestBody ChatRequest queryRequest) throws IOExc /** * 自定义模型非流式输出接口DEMO *

- * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致 + * Note:使用自己本地的飞流式输出自定义AI,接口输入和输出需与该样例保持一致 *

* * @param queryRequest @@ -262,7 +247,7 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE */ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) { RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter); - RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener); + RestAIClient.getInstance().restCompletions(promptService.buildPrompt(prompt), eventSourceListener); return sseEmitter; } @@ -276,11 +261,11 @@ private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter * @throws IOException */ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) - throws IOException { - String prompt = buildPrompt(queryRequest); + throws IOException { + String prompt = promptService.buildAutoPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, - prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); + prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); throw new ParamBusinessException(); } @@ -290,9 +275,17 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); messages.add(currentMessage); buildSseEmitter(sseEmitter, uid); - - OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter); - OpenAIClient.getInstance().streamChatCompletion(messages, openAIEventSourceListener); + LoginUser loginUser = ContextUtils.getLoginUser(); + OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, promptService, queryRequest,loginUser); + ChatCompletion chatCompletion = ChatCompletion.builder() + .messages(messages).stream(true).build(); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = PromptService.getToolsFunction(); + chatCompletion.setModel("gpt-3.5-turbo-0125"); + chatCompletion.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + chatCompletion.setToolChoice("auto"); + } + OpenAIClient.getInstance().streamChatCompletion(chatCompletion, openAIEventSourceListener); LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT); return sseEmitter; } @@ -308,7 +301,7 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE */ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("exceed max token length:{},input length:{}", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); @@ -338,11 +331,13 @@ private SseEmitter chatWithChat2dbAi(ChatQueryRequest queryRequest, SseEmitter s * @throws IOException */ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildAutoPrompt(queryRequest); if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) { log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH, prompt.length() / TOKEN_CONVERT_CHAR_LENGTH); throw new ParamBusinessException(); + }else{ + log.info("提示词 :{}",prompt); } List messages = (List)LocalCache.CACHE.get(uid); if (CollectionUtils.isNotEmpty(messages)) { @@ -356,9 +351,16 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse messages.add(currentMessage); buildSseEmitter(sseEmitter, uid); - - AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter); - AzureOpenAIClient.getInstance().streamCompletions(messages, sourceListener); + LoginUser loginUser = ContextUtils.getLoginUser(); + AzureOpenAIEventSourceListener sourceListener = new AzureOpenAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser); + AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages); + chatCompletionsOptions.setStream(true); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = PromptService.getToolsFunction(); + chatCompletionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + chatCompletionsOptions.setToolChoice("auto"); + } + AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, sourceListener); LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); return sseEmitter; } @@ -373,8 +375,8 @@ private SseEmitter chatWithAzureAi(ChatQueryRequest queryRequest, SseEmitter sse * @throws IOException */ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + String prompt = promptService.buildPrompt(queryRequest); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -394,13 +396,27 @@ private SseEmitter chatWithFastChatAi(ChatQueryRequest queryRequest, SseEmitter * @throws IOException */ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + String prompt = promptService.buildAutoPrompt(queryRequest); + log.info("原始提示词{}",prompt); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); - - ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter); - ZhipuChatAIClient.getInstance().streamCompletions(messages, sourceListener); + LoginUser loginUser = ContextUtils.getLoginUser(); + ZhipuChatAIEventSourceListener sourceListener = new ZhipuChatAIEventSourceListener(sseEmitter,promptService,queryRequest,loginUser); + String requestId = String.valueOf(System.currentTimeMillis()); + // 建议直接查看demo包代码,这里更新可能不及时 + ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() + .requestId(requestId) + .stream(true) + + .messages(messages) + .build(); + if(queryRequest.getDatabaseName()!=null){ + ToolsFunction function = PromptService.getToolsFunction(); + completionsOptions.setTools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))); + completionsOptions.setToolChoice("auto"); + } + ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, sourceListener); LocalCache.CACHE.put(uid, messages, LocalCache.TIMEOUT); return sseEmitter; } @@ -415,8 +431,8 @@ private SseEmitter chatWithZhipuChatAi(ChatQueryRequest queryRequest, SseEmitter * @throws IOException */ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + String prompt = promptService.buildPrompt(queryRequest); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -436,8 +452,8 @@ private SseEmitter chatWithTongyiChatAi(ChatQueryRequest queryRequest, SseEmitte * @throws IOException */ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + String prompt = promptService.buildPrompt(queryRequest); + List messages = promptService.getFastChatMessage(uid, prompt); buildSseEmitter(sseEmitter, uid); @@ -447,26 +463,7 @@ private SseEmitter chatWithBaichuanAi(ChatQueryRequest queryRequest, SseEmitter return sseEmitter; } - /** - * get fast chat message - * - * @param uid - * @param prompt - * @return - */ - private List getFastChatMessage(String uid, String prompt) { - List messages = (List)LocalCache.CACHE.get(uid); - if (CollectionUtils.isNotEmpty(messages)) { - if (messages.size() >= contextLength) { - messages = messages.subList(1, contextLength); - } - } else { - messages = Lists.newArrayList(); - } - FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); - messages.add(currentMessage); - return messages; - } + /** * chat with wenxin chat openai @@ -478,8 +475,8 @@ private List getFastChatMessage(String uid, String prompt) { * @throws IOException */ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); - List messages = getFastChatMessage(uid, prompt); + String prompt = promptService.buildPrompt(queryRequest); + List messages = promptService.getFastChatMessage(uid, prompt); if (messages.size() >= 2 && messages.size() % 2 == 0) { messages.remove(messages.size() - 1); } @@ -503,7 +500,7 @@ private SseEmitter chatWithWenxinAi(ChatQueryRequest queryRequest, SseEmitter ss * @throws IOException */ private SseEmitter chatWithClaudeAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException { - String prompt = buildPrompt(queryRequest); + String prompt = promptService.buildPrompt(queryRequest); ClaudeChatMessage claudeChatMessage = new ClaudeChatMessage(); claudeChatMessage.setText(prompt); ClaudeChatCompletionsOptions chatCompletionsOptions = new ClaudeChatCompletionsOptions(); @@ -546,270 +543,5 @@ private SseEmitter buildSseEmitter(SseEmitter sseEmitter, String uid) throws IOE return sseEmitter; } - /** - * 构建schema参数 - * - * @param tableQueryParam - * @param tableNames - * @return - */ - private String buildTableColumn(TableQueryParam tableQueryParam, - List tableNames) { - if (CollectionUtils.isEmpty(tableNames)) { - return ""; - } - List schemaContent = Lists.newArrayList(); - try { - schemaContent = tableNames.stream().map(tableName -> { - tableQueryParam.setTableName(tableName); - return queryTableDdl(tableName, tableQueryParam); - }).collect(Collectors.toList()); - } catch (Exception exception) { - log.error("query table error, do nothing"); - } - - return JSON.toJSONString(schemaContent); - } - - /** - * query table schema - * - * @param tableName - * @param request - * @return - */ - private String queryTableDdl(String tableName, TableQueryParam request) { - ShowCreateTableParam param = new ShowCreateTableParam(); - param.setTableName(tableName); - param.setDataSourceId(request.getDataSourceId()); - param.setDatabaseName(request.getDatabaseName()); - param.setSchemaName(request.getSchemaName()); - DataResult tableSchema = tableService.showCreateTable(param); - return tableSchema.getData(); - } - - /** - * 构建prompt - * - * @param queryRequest - * @return - */ - private String buildPrompt(ChatQueryRequest queryRequest) { - if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { - return queryRequest.getMessage(); - } - - // 查询schema信息 - String dataSourceType = queryDatabaseType(queryRequest); - String properties = ""; - if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { - TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); - properties = buildTableColumn(queryParam, queryRequest.getTableNames()); - } else { - properties = mappingDatabaseSchema(queryRequest); - } - String prompt = queryRequest.getMessage(); - String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() - : queryRequest.getPromptType(); - PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); - String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; - String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( - "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# " - + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, - properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", - pType.getDescription(), ext, prompt); - switch (pType) { - case SQL_2_SQL: - schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( - "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); - default: - break; - } - String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); - return cleanedInput; - } - - /** - * query chat2db apikey - * - * @return - */ - public String getApiKey() { - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); - String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); - // only sync for chat2db ai - if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) { - return null; - } - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return null; - } - return keyConfig.getContent(); - } - - /** - * query database type - * - * @param queryRequest - * @return - */ - public String queryDatabaseType(ChatQueryRequest queryRequest) { - // 查询schema信息 - DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId()); - String dataSourceType = dataResult.getData().getType(); - if (StringUtils.isBlank(dataSourceType)) { - dataSourceType = "MYSQL"; - } - return dataSourceType; - } - - public String mappingDatabaseSchema(ChatQueryRequest queryRequest) { - String properties = ""; - String apiKey = getApiKey(); - if (StringUtils.isNotBlank(apiKey)) { - boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData(); - if (res) { -// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest); - properties = queryDatabaseSchema(queryRequest); - } - } - return properties; - } - - /** - * query database schema - * - * @param queryRequest - * @return - * @throws IOException - */ - public String queryDatabaseSchema(ChatQueryRequest queryRequest) { - // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); - List> contentVector = new ArrayList<>(); - if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) { - return ""; - } - contentVector.add(response.getData().get(0).getEmbedding()); - - // search embedding - TableSchemaRequest tableSchemaRequest = new TableSchemaRequest(); - tableSchemaRequest.setSchemaVector(contentVector); - tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); - tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); - tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName()); - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return ""; - } - tableSchemaRequest.setApiKey(keyConfig.getContent()); - try { - DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest); - List schemas = Lists.newArrayList(); - if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { - for(TableSchema data: result.getData().getTableSchemas()){ - schemas.add(data.getTableSchema()); - } - } - if (CollectionUtils.isEmpty(schemas)) { - return ""; - } - String res = JSON.toJSONString(schemas); - log.info("search vector result:{}", res); - return res; - } catch (Exception exception) { - log.error("query table error, do nothing"); - return ""; - } - } - - /** - * query database schema - * - * @param queryRequest - * @return - * @throws IOException - */ - public String querySchemaByEs(ChatQueryRequest queryRequest) { - // search embedding - EsTableSchemaRequest tableSchemaRequest = new EsTableSchemaRequest(); - tableSchemaRequest.setSearchKey(queryRequest.getMessage()); - tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); - tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); - tableSchemaRequest.setSchemaName(queryRequest.getSchemaName()); - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); - if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { - return ""; - } - tableSchemaRequest.setApiKey(keyConfig.getContent()); - try { - DataResult result = gatewayClientService.schemaEsSearch(tableSchemaRequest); - List schemas = Lists.newArrayList(); - if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { - for(EsTableSchema data: result.getData().getTableSchemas()){ - schemas.add(data.getTableSchemaContent()); - } - } - if (CollectionUtils.isEmpty(schemas)) { - return ""; - } - String res = JSON.toJSONString(schemas); - log.info("search es result:{}", res); - return res; - } catch (Exception exception) { - log.error("query es table error, do nothing"); - return ""; - } - } - - /** - * distribute embedding with different AI - * - * @return - */ - public FastChatEmbeddingResponse distributeAIEmbedding(String input) { - ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); - Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); - String aiSqlSource = config.getContent(); - if (Objects.isNull(aiSqlSource)) { - return null; - } - AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource); - switch (Objects.requireNonNull(aiSqlSourceEnum)) { - case CHAT2DBAI: - return embeddingWithChat2dbAi(input); - case FASTCHATAI: - return embeddingWithFastChatAi(input); - } - return null; - } - - /** - * embedding with fast chat openai - * - * @param input - * @return - * @throws IOException - */ - private FastChatEmbeddingResponse embeddingWithFastChatAi(String input) { - FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input); - return response; - } - - /** - * embedding with open ai - * - * @param input - * @return - */ - private FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) { - FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input); - return embeddings; - } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java index c8c694309..70df31c6f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/EmbeddingController.java @@ -242,7 +242,7 @@ public void syncTableVector(TableBriefQueryRequest param) throws Exception { return; } - String apiKey = getApiKey(); + String apiKey = promptService.getApiKey(); if (StringUtils.isBlank(apiKey)) { return; } @@ -281,7 +281,7 @@ private void saveTableEmbedding(String tableSchema, TableSchemaRequest tableSche List> contentVector = new ArrayList<>(); for(String str : schemaList){ // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(str); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str); if(response == null){ throw new ParamBusinessException(); } @@ -310,7 +310,7 @@ public void syncTableEs(TableBriefQueryRequest param) throws Exception { return; } - String apiKey = getApiKey(); + String apiKey = promptService.getApiKey(); if (StringUtils.isBlank(apiKey)) { return; } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java index 6ff16ee09..6ef0731ac 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/KnowledgeController.java @@ -70,7 +70,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request) contentWordCount.add(str.length()); // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(str); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(str); if(response == null){ continue; } @@ -97,7 +97,7 @@ public ActionResult embeddings(MultipartFile file, HttpServletRequest request) public SseEmitter search(ChatQueryRequest queryRequest, @RequestHeader Map headers) throws Exception { // request embedding - FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); + FastChatEmbeddingResponse response = promptService.distributeAIEmbedding(queryRequest.getMessage()); List> contentVector = new ArrayList<>(); contentVector.add(response.getData().get(0).getEmbedding()); diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java index 0c6180667..94caf7d4f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/TextGenerationController.java @@ -63,8 +63,8 @@ public SseEmitter prompt(ChatQueryRequest queryRequest, @RequestHeader Map chatMessages, EventSourceListener eventSourceListener) { - if (CollectionUtils.isEmpty(chatMessages)) { - log.error("param error:Azure Prompt cannot be empty"); - throw new ParamBusinessException("prompt"); - } + public void streamCompletions(AzureChatCompletionsOptions chatCompletionsOptions, EventSourceListener eventSourceListener) { if (Objects.isNull(eventSourceListener)) { log.error("param error:AzureEventSourceListener cannot be empty"); throw new ParamBusinessException(); } - log.info("Azure Open AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); try { - - AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(chatMessages); chatCompletionsOptions.setStream(true); chatCompletionsOptions.setModel(this.deployId); - EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); @@ -172,7 +164,7 @@ public void streamCompletions(List chatMessages, EventSourceLi if (!endpoint.endsWith("/")) { endpoint = endpoint + "/"; } - String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2023-05-15"; + String url = this.endpoint + "openai/deployments/"+ deployId + "/chat/completions?api-version=2024-02-15-preview"; Request request = new Request.Builder() .url(url) .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java index 4488bd6b8..2b9ab4a99 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/listener/AzureOpenAIEventSourceListener.java @@ -1,15 +1,32 @@ package ai.chat2db.server.web.api.controller.ai.azure.listener; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.web.api.controller.ai.azure.client.AzureOpenAIClient; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatChoice; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletions; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatCompletionsOptions; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatMessage; +import ai.chat2db.server.web.api.controller.ai.azure.model.AzureChatRole; import ai.chat2db.server.web.api.controller.ai.azure.model.AzureCompletionsUsage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; +import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; +import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; + import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; + import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import okhttp3.Response; @@ -26,123 +43,26 @@ * @date 2023-02-22 */ @Slf4j -public class AzureOpenAIEventSourceListener extends EventSourceListener { +public class AzureOpenAIEventSourceListener extends OpenAIEventSourceListener { - private SseEmitter sseEmitter; - private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); - - public AzureOpenAIEventSourceListener(SseEmitter sseEmitter) { - this.sseEmitter = sseEmitter; + public AzureOpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, + ChatQueryRequest queryRequest, LoginUser loginUser) { + super(sseEmitter, promptService, queryRequest, loginUser); } - /** - * {@inheritDoc} - */ @Override - public void onOpen(EventSource eventSource, Response response) { - log.info("AzureOpenAI建立sse连接..."); + public String getName(){ + return "AzureOpenAI"; } - /** - * {@inheritDoc} - */ - @SneakyThrows - @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("AzureOpenAI返回数据:{}", data); - if (data.equals("[DONE]")) { - log.info("AzureOpenAI返回数据结束了"); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); - return; - } - - AzureChatCompletions chatCompletions = mapper.readValue(data, AzureChatCompletions.class); - String text = ""; - log.info("Model ID={} is created at {}.", chatCompletions.getId(), - chatCompletions.getCreated()); - for (AzureChatChoice choice : chatCompletions.getChoices()) { - AzureChatMessage message = choice.getDelta(); - if (message != null) { - log.info("Index: {}, Chat Role: {}", choice.getIndex(), message.getRole()); - if (message.getContent() != null) { - text = message.getContent(); - } - } - } - - AzureCompletionsUsage usage = chatCompletions.getUsage(); - if (usage != null) { - log.info( - "Usage: number of prompt token is {}, number of completion token is {}, and number of total " - + "tokens in request and response is {}.%n", usage.getPromptTokens(), - usage.getCompletionTokens(), usage.getTotalTokens()); - } - - Message message = new Message(); - message.setContent(text); - sseEmitter.send(SseEmitter.event() - .id(null) - .data(message) - .reconnectTime(3000)); - } - - @Override - public void onClosed(EventSource eventSource) { - try { - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - } catch (IOException e) { - throw new RuntimeException(e); - } - sseEmitter.complete(); - log.info("AzureOpenAI close sse connection..."); - } - - @Override - public void onFailure(EventSource eventSource, Throwable t, Response response) { - try { - if (Objects.isNull(response)) { - String message = t.getMessage(); - Message sseMessage = new Message(); - sseMessage.setContent(message); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - return; - } - ResponseBody body = response.body(); - String bodyString = Objects.nonNull(t) ? t.getMessage() : ""; - if (Objects.nonNull(body)) { - bodyString = body.string(); - if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) { - bodyString = t.getMessage(); - } - log.error("Azure OpenAI sse response:{}", bodyString); - } else { - log.error("Azure OpenAI sse response:{},error:{}", response, t); - } - eventSource.cancel(); - Message message = new Message(); - message.setContent("Azure OpenAI error:" + bodyString); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - } catch (Exception exception) { - log.error("Azure OpenAI发送数据异常:", exception); - } + @Override + public void functionCall(String prompt){ + AzureChatMessage currentMessage = new AzureChatMessage(AzureChatRole.USER).setContent(prompt); + List messages = new ArrayList<>(); + messages.add(currentMessage); + AzureChatCompletionsOptions chatCompletionsOptions = new AzureChatCompletionsOptions(messages); + chatCompletionsOptions.setStream(true); + AzureOpenAIClient.getInstance().streamCompletions(chatCompletionsOptions, this); } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java index 1d6198e57..8d33166b3 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/azure/model/AzureChatCompletionsOptions.java @@ -7,6 +7,8 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.tool.Tools; + import lombok.Data; /** @@ -391,4 +393,11 @@ public AzureChatCompletionsOptions setModel(String model) { this.model = model; return this; } + + // 新添加的参数 + @JsonProperty(value = "tool_choice") + private String toolChoice; // 工具选择策略 + + @JsonProperty(value = "tools") + private List tools; // 工具列表 } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java index 0f0b6d84f..295d39cff 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/chat2db/client/Chat2DBAIStreamClient.java @@ -1,21 +1,15 @@ package ai.chat2db.server.web.api.controller.ai.chat2db.client; -import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; -import ai.chat2db.server.domain.api.model.Config; -import ai.chat2db.server.domain.api.service.ConfigService; -import ai.chat2db.server.tools.base.wrapper.result.DataResult; import ai.chat2db.server.tools.common.exception.ParamBusinessException; import ai.chat2db.server.web.api.controller.ai.chat2db.interceptor.Chat2dbHeaderAuthorizationInterceptor; import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatOpenAiApi; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbedding; import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; -import ai.chat2db.server.web.api.util.ApplicationContextUtil; import cn.hutool.http.ContentType; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.ChatCompletion; import com.unfbx.chatgpt.entity.chat.Message; -import com.unfbx.chatgpt.interceptor.HeaderAuthorizationInterceptor; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import okhttp3.*; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java index 9e9745c75..f6e833a01 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/enums/PromptType.java @@ -38,6 +38,12 @@ public enum PromptType implements BaseEnum { * text generation */ TEXT_GENERATION("文本生成"), + + + /** + * GET_TABLE_COLUMNS + */ + GET_TABLE_COLUMNS("获取指定表的属性"), ; final String description; diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java index 9ebf711c2..1d3de3bc7 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/client/OpenAIClient.java @@ -4,6 +4,7 @@ import java.net.InetSocketAddress; import java.net.Proxy; import java.util.Objects; +import java.util.concurrent.TimeUnit; import ai.chat2db.server.domain.api.model.Config; import ai.chat2db.server.domain.api.service.ConfigService; @@ -93,7 +94,17 @@ public static void refresh() { log.info("refresh openai apikey:{}", maskApiKey(apikey)); if (Objects.nonNull(host) && Objects.nonNull(port)) { Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port)); - OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build(); + OkHttpClient okHttpClient = new OkHttpClient.Builder() + // 设置连接超时为10秒 + .connectTimeout(10, TimeUnit.SECONDS) + // 设置读取超时为30秒 + .readTimeout(30, TimeUnit.SECONDS) + // 设置写入超时为15秒 + .writeTimeout(15, TimeUnit.SECONDS) + // 设置整个调用的超时为1分钟 + .callTimeout(1, TimeUnit.MINUTES) + .proxy(proxy) + .build(); OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey( Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build(); } else { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java index ccadf6d68..54609e72f 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/openai/listener/OpenAIEventSourceListener.java @@ -1,20 +1,39 @@ package ai.chat2db.server.web.api.controller.ai.openai.listener; -import java.util.Objects; - +import ai.chat2db.server.domain.repository.Dbutils; +import ai.chat2db.server.tools.common.model.Context; +import ai.chat2db.server.tools.common.model.LoginUser; +import ai.chat2db.server.tools.common.util.ContextUtils; +import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; +import com.alibaba.fastjson2.JSONArray; +import com.alibaba.fastjson2.JSONObject; import com.fasterxml.jackson.databind.DeserializationFeature; import com.fasterxml.jackson.databind.ObjectMapper; import com.unfbx.chatgpt.entity.chat.Message; +import com.unfbx.chatgpt.entity.chat.tool.ToolCallFunction; +import com.unfbx.chatgpt.entity.chat.tool.ToolCalls; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import okhttp3.Response; import okhttp3.ResponseBody; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; + +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Set; +import java.util.List; +import java.util.Objects; + /** * 描述:OpenAIEventSourceListener * @@ -24,61 +43,203 @@ @Slf4j public class OpenAIEventSourceListener extends EventSourceListener { - private SseEmitter sseEmitter; + private final SseEmitter sseEmitter; + + protected final PromptService promptService;; + + private final ChatQueryRequest queryRequest; - public OpenAIEventSourceListener(SseEmitter sseEmitter) { + public final LoginUser loginUser; + + private List toolCalls = new ArrayList<>(); + + + public OpenAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, ChatQueryRequest queryRequest, LoginUser loginUser) { this.sseEmitter = sseEmitter; + this.promptService = promptService; + this.queryRequest = queryRequest; + this.loginUser = loginUser; + } + + public static List mergeToolCallsLists(List list1, List list2) { + List mergedList = new ArrayList<>(list1); + if (list2.isEmpty()) { + return mergedList; + } + ToolCalls item2 = list2.get(0); + boolean isMerged = false; + // 反向遍历 + for (int i = list1.size() - 1; i >= 0; i--) { + ToolCalls item1 = list1.get(i); + if (item2.getId() == null || Objects.equals(item1.getId(), item2.getId())) { + mergedList.set(i, mergeToolCalls(item1, item2)); + isMerged = true; + break; + } + } + if (!isMerged) { + // 如果 list2 中的对象与 list1 中的任何对象都不匹配,则作为新对象添加 + mergedList.add(item2); + } + return mergedList; + } + + private static ToolCalls mergeToolCalls(ToolCalls tc1, ToolCalls tc2) { + if (tc1 == null) return tc2; + if (tc2 == null) return tc1; + + // 相同的逻辑,只是当 id 为 null 时进行合并 + String id = tc1.getId() != null ? tc1.getId() : tc2.getId(); + String type = mergeStrings(tc1.getType(), tc2.getType()); + ToolCallFunction function = mergeToolCallFunctions(tc1.getFunction(), tc2.getFunction()); + + return new ToolCalls(id, type, function); + } + + private static ToolCallFunction mergeToolCallFunctions(ToolCallFunction f1, ToolCallFunction f2) { + if (f1 == null) return f2; + if (f2 == null) return f1; + + String name = mergeStrings(f1.getName(), f2.getName()); + String arguments = mergeStrings(f1.getArguments(), f2.getArguments()); + + return new ToolCallFunction(name, arguments); + } + + private static String mergeStrings(String str1, String str2) { + if (str1 != null && str2 != null) { + // Concatenate both strings + return str1 + str2; + } else if (str1 != null) { + return str1; + } else { + return str2; + } } + + public String getName() { + return "OpenAI"; + } /** * {@inheritDoc} */ @Override public void onOpen(EventSource eventSource, Response response) { - log.info("OpenAI建立sse连接..."); + log.info("{}建立sse连接...",getName()); } + + public void functionCall(String prompt){ + List messages = new ArrayList<>(); + Message currentMessage = Message.builder().content(prompt).role(Message.Role.USER).build(); + messages.add(currentMessage); + OpenAIClient.getInstance().streamChatCompletion(messages, this); + } + + + public void handleTableNames(Set tableNames, Object instance) { + if (instance instanceof JSONArray) { + ((JSONArray) instance).forEach(item -> handleTableNames(tableNames, item)); + } else if (instance instanceof JSONObject) { + ((JSONObject) instance).forEach((key, value) -> handleTableNames(tableNames, value)); + } else if (instance instanceof String) { + String tableName = (String) instance; + List queryTableNames = queryRequest.getTableNames(); + if (queryTableNames != null) { + String mostSimilarTableName = queryTableNames.stream() + // 根据相似度排序 + .min(Comparator.comparingInt(existingTableName -> StringUtils.getLevenshteinDistance(existingTableName, tableName))) + .orElse(tableName); + tableNames.add(mostSimilarTableName); + }else{ + tableNames.add(tableName); + } + + } + } /** * {@inheritDoc} */ @SneakyThrows @Override public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("OpenAI返回数据:{}", data); + String scheme = getName(); + log.info("{}返回数据:{}",scheme,data); if (data.equals("[DONE]")) { - log.info("OpenAI返回数据结束了"); + if (toolCalls.isEmpty()) { + log.info("{}返回数据结束了",scheme); + sseEmitter.send(SseEmitter.event() + .id("[DONE]") + .data("[DONE]") + .reconnectTime(3000)); + sseEmitter.complete(); + return; + } + Set tableNames = new HashSet<>(); + for (ToolCalls toolCall : toolCalls) { + String callId = toolCall.getId(); + ToolCallFunction function = toolCall.getFunction(); + if (function != null && Objects.nonNull(function.getArguments())) { + String functionName = function.getName(); + if ("get_table_columns".equals(functionName)) { + JSONObject arguments = JSONObject.parse(function.getArguments()); + handleTableNames(tableNames,arguments.get("table_names")); + log.info("原始参数:{},处理后:{}",arguments,tableNames); + } + } + } + Message message = new Message(); + message.setContent("选择表" + tableNames+"\n"); sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); + .data(message) + .reconnectTime(3000)); + queryRequest.setTableNames(new ArrayList<>(tableNames)); + ContextUtils.setContext(Context.builder() + .loginUser(loginUser) + .build()); + Dbutils.setSession(); + String prompt = promptService.buildPrompt(queryRequest); + Dbutils.removeSession(); + prompt = prompt.replaceAll("#", ""); + log.info("{} 新提示词 :{}",scheme,prompt); + functionCall(prompt); + toolCalls.clear(); return; } ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); // 读取Json ChatCompletionResponse completionResponse = mapper.readValue(data, ChatCompletionResponse.class); - String text = completionResponse.getChoices().get(0).getDelta() == null - ? completionResponse.getChoices().get(0).getText() - : completionResponse.getChoices().get(0).getDelta().getContent(); + if(CollectionUtils.isEmpty(completionResponse.getChoices())){ + return; + } + Message delta = completionResponse.getChoices().get(0).getDelta(); + if (delta != null && delta.getToolCalls() != null) { + this.toolCalls = mergeToolCallsLists(this.toolCalls, delta.getToolCalls()); + } + String text = delta == null + ? completionResponse.getChoices().get(0).getText() + : delta.getContent(); Message message = new Message(); if (text != null) { message.setContent(text); sseEmitter.send(SseEmitter.event() - .id(completionResponse.getId()) - .data(message) - .reconnectTime(3000)); + .id(completionResponse.getId()) + .data(message) + .reconnectTime(3000)); } } @Override public void onClosed(EventSource eventSource) { - sseEmitter.complete(); - log.info("OpenAI关闭sse连接..."); +// sseEmitter.complete(); +// log.info("OpenAI关闭sse连接..."); } @Override public void onFailure(EventSource eventSource, Throwable t, Response response) { + String scheme = getName(); try { if (Objects.isNull(response)) { String message = t.getMessage(); @@ -88,11 +249,11 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { Message sseMessage = new Message(); sseMessage.setContent(message); sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); + .id("[ERROR]") + .data(sseMessage)); sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); + .id("[DONE]") + .data("[DONE]")); sseEmitter.complete(); return; } @@ -100,22 +261,22 @@ public void onFailure(EventSource eventSource, Throwable t, Response response) { String bodyString = null; if (Objects.nonNull(body)) { bodyString = body.string(); - log.error("OpenAI sse连接异常data:{}", bodyString, t); + log.error("{} sse连接异常data:{}",scheme, bodyString, t); } else { - log.error("OpenAI sse连接异常data:{}", response, t); + log.error("{} sse连接异常data:{}",scheme, response, t); } eventSource.cancel(); Message message = new Message(); message.setContent("出现异常,请在帮助中查看详细日志:" + bodyString); sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); + .id("[ERROR]") + .data(message)); sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); + .id("[DONE]") + .data("[DONE]")); sseEmitter.complete(); } catch (Exception exception) { - log.error("发送数据异常:", exception); + log.error("{}发送数据异常:", scheme,exception); } } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java new file mode 100644 index 000000000..503c5543a --- /dev/null +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/utils/PromptService.java @@ -0,0 +1,481 @@ +package ai.chat2db.server.web.api.controller.ai.utils; + +import java.io.IOException; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import org.apache.commons.collections.MapUtils; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.poi.ss.formula.functions.T; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +import com.alibaba.fastjson2.JSON; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.unfbx.chatgpt.entity.chat.Parameters; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; + +import ai.chat2db.server.domain.api.enums.AiSqlSourceEnum; +import ai.chat2db.server.domain.api.model.Config; +import ai.chat2db.server.domain.api.model.DataSource; +import ai.chat2db.server.domain.api.param.ShowCreateTableParam; +import ai.chat2db.server.domain.api.param.TablePageQueryParam; +import ai.chat2db.server.domain.api.param.TableQueryParam; +import ai.chat2db.server.domain.api.param.TableSelector; +import ai.chat2db.server.domain.api.service.ConfigService; +import ai.chat2db.server.domain.api.service.DataSourceService; +import ai.chat2db.server.domain.api.service.TableService; +import ai.chat2db.server.tools.base.enums.WhiteListTypeEnum; +import ai.chat2db.server.tools.base.wrapper.result.DataResult; +import ai.chat2db.server.tools.base.wrapper.result.PageResult; +import ai.chat2db.server.tools.common.util.EasyEnumUtils; +import ai.chat2db.server.web.api.aspect.ConnectionInfoAspect; +import ai.chat2db.server.web.api.controller.ai.chat2db.client.Chat2dbAIClient; +import ai.chat2db.server.web.api.controller.ai.config.LocalCache; +import ai.chat2db.server.web.api.controller.ai.converter.ChatConverter; +import ai.chat2db.server.web.api.controller.ai.enums.PromptType; +import ai.chat2db.server.web.api.controller.ai.fastchat.client.FastChatAIClient; +import ai.chat2db.server.web.api.controller.ai.fastchat.embeddings.FastChatEmbeddingResponse; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.rest.client.RestAIClient; +import ai.chat2db.server.web.api.controller.rdb.converter.RdbWebConverter; +import ai.chat2db.server.web.api.http.GatewayClientService; +import ai.chat2db.server.web.api.http.model.TableSchema; +import ai.chat2db.server.web.api.http.request.TableSchemaRequest; +import ai.chat2db.server.web.api.http.request.WhiteListRequest; +import ai.chat2db.server.web.api.http.response.TableSchemaResponse; +import ai.chat2db.server.web.api.util.ApplicationContextUtil; +import ai.chat2db.spi.MetaData; +import ai.chat2db.spi.model.Table; +import ai.chat2db.spi.model.TableColumn; +import ai.chat2db.spi.sql.Chat2DBContext; +import jakarta.annotation.Resource; +import lombok.extern.slf4j.Slf4j; + + +@Slf4j +@ConnectionInfoAspect +@Service +public class PromptService { + + + @Value("${chatgpt.context.length}") + private Integer contextLength; + + + @Autowired + private TableService tableService; + + @Autowired + private DataSourceService dataSourceService; + + + @Autowired + private ChatConverter chatConverter; + + + @Resource + private GatewayClientService gatewayClientService; + + + @Autowired + private RdbWebConverter rdbWebConverter; + + + /** + * 构建prompt + * + * @param queryRequest + * @return + */ + public String buildPrompt(ChatQueryRequest queryRequest) { + if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { + return queryRequest.getMessage(); + } + + // 查询schema信息 + String dataSourceType = queryDatabaseType(queryRequest); + String properties = ""; + if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { + TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); + properties = buildTableColumn(queryParam, queryRequest.getTableNames()); + } else { + properties = mappingDatabaseSchema(queryRequest); + } + String prompt = queryRequest.getMessage(); + String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() + : queryRequest.getPromptType(); + PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; + String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( + "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables, with their properties:\n#\n# " + + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, + properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", + pType.getDescription(), ext, prompt); + switch (pType) { + case SQL_2_SQL: + schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); + default: + break; + } + String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); + return cleanedInput; + } + + public String mappingDatabaseSchema(ChatQueryRequest queryRequest) { + String properties = ""; + String apiKey = getApiKey(); + if (StringUtils.isNotBlank(apiKey)) { + boolean res = gatewayClientService.checkInWhite(new WhiteListRequest(apiKey, WhiteListTypeEnum.VECTOR.getCode())).getData(); + if (res) { +// properties = queryDatabaseSchema(queryRequest) + querySchemaByEs(queryRequest); + properties = queryDatabaseSchema(queryRequest); + } + } + return properties; + } + + + /** + * query chat2db apikey + * + * @return + */ + public String getApiKey() { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + String aiSqlSource = AiSqlSourceEnum.CHAT2DBAI.getCode(); + // only sync for chat2db ai + if (Objects.isNull(config) || !aiSqlSource.equals(config.getContent())) { + return null; + } + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return null; + } + return keyConfig.getContent(); + } + + /** + * 构建schema参数 + * + * @param tableQueryParam + * @param tableNames + * @return + */ + public String buildTableColumn(TableQueryParam tableQueryParam, + List tableNames) { + if (CollectionUtils.isEmpty(tableNames)) { + return ""; + } + try { + return tableNames.stream().map(tableName -> { + tableQueryParam.setTableName(tableName); + return queryTableDdl(tableName, tableQueryParam); + }).collect(Collectors.joining(";\n")); + } catch (Exception exception) { + log.error("query table error, do nothing"); + } + + return ""; + } + + /** + * query table schema + * + * @param tableName + * @param request + * @return + */ + public String queryTableDdl(String tableName, TableQueryParam request) { + ShowCreateTableParam param = new ShowCreateTableParam(); + param.setTableName(tableName); + param.setDataSourceId(request.getDataSourceId()); + param.setDatabaseName(request.getDatabaseName()); + param.setSchemaName(request.getSchemaName()); + DataResult tableSchema = tableService.showCreateTable(param); + return tableSchema.getData(); + } + + /** + * query database schema + * + * @param queryRequest + * @return + * @throws IOException + */ + public String queryDatabaseSchema(ChatQueryRequest queryRequest) { + // request embedding + FastChatEmbeddingResponse response = distributeAIEmbedding(queryRequest.getMessage()); + List> contentVector = new ArrayList<>(); + if (Objects.isNull(response) || CollectionUtils.isEmpty(response.getData())) { + return ""; + } + contentVector.add(response.getData().get(0).getEmbedding()); + + // search embedding + TableSchemaRequest tableSchemaRequest = new TableSchemaRequest(); + tableSchemaRequest.setSchemaVector(contentVector); + tableSchemaRequest.setDataSourceId(queryRequest.getDataSourceId()); + tableSchemaRequest.setDatabaseName(queryRequest.getDatabaseName()); + tableSchemaRequest.setDataSourceSchema(queryRequest.getSchemaName()); + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config keyConfig = configService.find(Chat2dbAIClient.CHAT2DB_OPENAI_KEY).getData(); + if (Objects.isNull(keyConfig) || StringUtils.isBlank(keyConfig.getContent())) { + return ""; + } + tableSchemaRequest.setApiKey(keyConfig.getContent()); + try { + DataResult result = gatewayClientService.schemaVectorSearch(tableSchemaRequest); + List schemas = Lists.newArrayList(); + if (Objects.nonNull(result.getData()) && CollectionUtils.isNotEmpty(result.getData().getTableSchemas())) { + for(TableSchema data: result.getData().getTableSchemas()){ + schemas.add(data.getTableSchema()); + } + } + if (CollectionUtils.isEmpty(schemas)) { + return ""; + } + String res = JSON.toJSONString(schemas); + log.info("search vector result:{}", res); + return res; + } catch (Exception exception) { + log.error("query table error, do nothing"); + return ""; + } + } + + /** + * distribute embedding with different AI + * + * @return + */ + public FastChatEmbeddingResponse distributeAIEmbedding(String input) { + ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); + Config config = configService.find(RestAIClient.AI_SQL_SOURCE).getData(); + String aiSqlSource = config.getContent(); + if (Objects.isNull(aiSqlSource)) { + return null; + } + AiSqlSourceEnum aiSqlSourceEnum = AiSqlSourceEnum.getByName(aiSqlSource); + switch (Objects.requireNonNull(aiSqlSourceEnum)) { + case CHAT2DBAI: + return embeddingWithChat2dbAi(input); + case FASTCHATAI: + return embeddingWithFastChatAi(input); + } + return null; + } + + /** + * embedding with fast chat openai + * + * @param input + * @return + * @throws IOException + */ + public FastChatEmbeddingResponse embeddingWithFastChatAi(String input) { + FastChatEmbeddingResponse response = FastChatAIClient.getInstance().embeddings(input); + return response; + } + + /** + * embedding with open ai + * + * @param input + * @return + */ + public FastChatEmbeddingResponse embeddingWithChat2dbAi(String input) { + FastChatEmbeddingResponse embeddings = Chat2dbAIClient.getInstance().embeddings(input); + return embeddings; + } + + /** + * 构建prompt + * + * @param queryRequest + * @return + */ + public String buildAutoPrompt(ChatQueryRequest queryRequest) { + if (PromptType.TEXT_GENERATION.getCode().equals(queryRequest.getPromptType())) { + return queryRequest.getMessage(); + } + // 查询schema信息 + String dataSourceType = queryDatabaseType(queryRequest); + String properties = ""; + if (CollectionUtils.isNotEmpty(queryRequest.getTableNames())) { + TableQueryParam queryParam = chatConverter.chat2tableQuery(queryRequest); + properties = buildTableColumn(queryParam, queryRequest.getTableNames()); + } else { + properties = queryDatabaseTables(queryRequest); + } + String prompt = queryRequest.getMessage(); + String promptType = StringUtils.isBlank(queryRequest.getPromptType()) ? PromptType.NL_2_SQL.getCode() + : queryRequest.getPromptType(); + PromptType pType = EasyEnumUtils.getEnum(PromptType.class, promptType); + if (StringUtils.isNotEmpty(properties)) { + pType = PromptType.GET_TABLE_COLUMNS; + } + String ext = StringUtils.isNotBlank(queryRequest.getExt()) ? queryRequest.getExt() : ""; + String schemaProperty = StringUtils.isNotEmpty(properties) ? String.format( + "### 请根据以下table properties和SQL input%s. %s\n#\n### %s SQL tables:\n#\n# " + + "%s\n#\n#\n### SQL input: %s", pType.getDescription(), ext, dataSourceType, + properties, prompt) : String.format("### 请根据以下SQL input%s. %s\n#\n### SQL input: %s", + pType.getDescription(), ext, prompt); + switch (pType) { + case SQL_2_SQL: + schemaProperty = StringUtils.isNotBlank(queryRequest.getDestSqlType()) ? String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, queryRequest.getDestSqlType()) : String.format( + "%s\n#\n### 目标SQL类型: %s", schemaProperty, dataSourceType); + default: + break; + } + String cleanedInput = schemaProperty.replaceAll("[\r\t]", ""); + return cleanedInput; + } + + + /** + * query database type + * + * @param queryRequest + * @return + */ + public String queryDatabaseType(ChatQueryRequest queryRequest) { + // 查询schema信息 + DataResult dataResult = dataSourceService.queryById(queryRequest.getDataSourceId()); + String dataSourceType = dataResult.getData().getType(); + if (StringUtils.isBlank(dataSourceType)) { + dataSourceType = "MYSQL"; + } + return dataSourceType; + } + + + /** + * 根据给定的表对象找出所有可能的外键列 + * @return 外键列名列表 + */ + public static List findPossibleForeignKeys(List columns) { + List foreignKeys = new ArrayList<>(); + for (TableColumn column : columns) { + String columnName = column.getName(); + // 假设TableColumn类有一个getTableName方法可以获取列所属的表名 + String tableName = column.getTableName(); + Boolean primaryKey = column.getPrimaryKey(); + + // 检查列名是否符合`关联表_id`的格式,并且列名前半部分不等于表名 + if (columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey)) { + // 从列名中移除"_id"以获取可能的关联表名 + String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3); + + if (!potentialForeignKeyTable.equals(tableName)) { + foreignKeys.add(columnName); + } + } + } + return foreignKeys; + } + + /** + * query database schema + * + * @param queryRequest + * @return + * @throws IOException + */ + public String queryDatabaseTables(ChatQueryRequest queryRequest) { + try { + TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(queryRequest); + queryParam.queryAll(); + TableSelector tableSelector = new TableSelector(); + tableSelector.setColumnList(true); + tableSelector.setIndexList(false); + PageResult
tables = tableService.pageQuery(queryParam,tableSelector); + List tableNames = new ArrayList<>(); + String properties = tables.getData().stream().map(table -> { + tableNames.add(table.getName()); + StringBuilder sb = new StringBuilder(table.getName()); // 直接在初始化时加入表名 + String comment = table.getComment(); + List columns = table.getColumnList(); + List foreignKeys = findPossibleForeignKeys(columns); + + // 只有当有注释或外键时才添加额外信息 + if(StringUtils.isNotEmpty(comment) || !foreignKeys.isEmpty()){ + sb.append("(").append(comment); + + // 如果存在外键,添加外键信息 + if(!foreignKeys.isEmpty()){ + // 如果注释和外键都存在,先添加一个分隔符 + if(StringUtils.isNotEmpty(comment)) { + sb.append("; "); + } + sb.append("外键:").append(String.join(", ", foreignKeys)); // 优化外键的展示 + } + sb.append(")"); + } + return sb.toString(); // 在映射阶段直接转换为字符串 + }) + .collect(Collectors.joining(",")); + queryRequest.setTableNames(tableNames); + return properties; + } catch (Exception e) { + log.error("query table error:{}, do nothing", e.getMessage()); + return ""; + } + } + + public static ToolsFunction getToolsFunction(){ + return ToolsFunction.builder() + .name("get_table_columns") + .description(PromptType.GET_TABLE_COLUMNS.getDescription()) + .parameters(Parameters.builder() + .type("object") + .properties(ImmutableMap.builder() + .put("table_names", ImmutableMap.builder() + .put("description", "表名,例如```User```") + .put("type", "array") + .put("items", ImmutableMap.of("type", "string")) + .put("uniqueItems", true) + .build()) + .build()) + .required(List.of("table_name")) + .build()) + .build(); + } + + + /** + * get fast chat message + * + * @param uid + * @param prompt + * @return + */ + public List getFastChatMessage(String uid, String prompt) { + List messages = (List)LocalCache.CACHE.get(uid); + if (CollectionUtils.isNotEmpty(messages)) { + if (messages.size() >= contextLength) { + messages = messages.subList(1, contextLength); + } + } else { + messages = Lists.newArrayList(); + } + FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); + messages.add(currentMessage); + return messages; + } +} diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java index f205f17f5..db0d35fa6 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIClient.java @@ -58,8 +58,8 @@ private static ZhipuChatAIStreamClient singleton() { public static void refresh() { String apiKey = ""; - String apiHost = "https://open.bigmodel.cn/api/paas/v3/model-api/"; - String model = "chatglm_turbo"; + String apiHost = "https://open.bigmodel.cn/api/paas/v4/chat/completions"; + String model = "glm-4"; ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); Config apiHostConfig = configService.find(ZHIPU_HOST).getData(); if (apiHostConfig != null && StringUtils.isNotBlank(apiHostConfig.getContent())) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java index 550c929eb..ef0ec8071 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/client/ZhipuChatAIStreamClient.java @@ -1,7 +1,6 @@ package ai.chat2db.server.web.api.controller.ai.zhipu.client; import ai.chat2db.server.tools.common.exception.ParamBusinessException; -import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import ai.chat2db.server.web.api.controller.ai.zhipu.interceptor.ZhipuChatHeaderAuthorizationInterceptor; import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import cn.hutool.http.ContentType; @@ -16,10 +15,8 @@ import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSources; -import org.apache.commons.collections4.CollectionUtils; import org.jetbrains.annotations.NotNull; -import java.util.List; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -69,7 +66,6 @@ public class ZhipuChatAIStreamClient { @Getter private OkHttpClient okHttpClient; - /** * @param builder */ @@ -90,13 +86,12 @@ private ZhipuChatAIStreamClient(Builder builder) { * okhttpclient */ private OkHttpClient okHttpClient() { - OkHttpClient okHttpClient = new OkHttpClient - .Builder() - .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret)) - .connectTimeout(10, TimeUnit.SECONDS) - .writeTimeout(50, TimeUnit.SECONDS) - .readTimeout(50, TimeUnit.SECONDS) - .build(); + OkHttpClient okHttpClient = new OkHttpClient.Builder() + .addInterceptor(new ZhipuChatHeaderAuthorizationInterceptor(this.key, this.secret)) + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(50, TimeUnit.SECONDS) + .readTimeout(50, TimeUnit.SECONDS) + .build(); return okHttpClient; } @@ -184,34 +179,26 @@ public ZhipuChatAIStreamClient build() { * @param chatMessages * @param eventSourceListener */ - public void streamCompletions(List chatMessages, EventSourceListener eventSourceListener) { - if (CollectionUtils.isEmpty(chatMessages)) { - log.error("param error:Zhipu Chat Prompt cannot be empty"); - throw new ParamBusinessException("prompt"); - } + public void streamCompletions(ZhipuChatCompletionsOptions completionsOptions, EventSourceListener eventSourceListener) { + if (Objects.isNull(eventSourceListener)) { log.error("param error:Zhipu ChatEventSourceListener cannot be empty"); throw new ParamBusinessException(); } - log.info("Zhipu Chat AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent()); + completionsOptions.setModel(this.model); try { - // 建议直接查看demo包代码,这里更新可能不及时 - ZhipuChatCompletionsOptions completionsOptions = new ZhipuChatCompletionsOptions(); - completionsOptions.setPrompt(chatMessages); - completionsOptions.setModel(this.model); - String requestId = String.valueOf(System.currentTimeMillis()); - completionsOptions.setRequestId(requestId); + ObjectMapper mapper = new ObjectMapper(); mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); String requestBody = mapper.writeValueAsString(completionsOptions); - String url = this.apiHost + "/" + this.model + "/" + "sse-invoke"; + String url = this.apiHost; EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); Request request = new Request.Builder() - .url(url) - .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) - .build(); - //创建事件 + .url(url) + .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) + .build(); + // 创建事件 EventSource eventSource = factory.newEventSource(request, eventSourceListener); log.info("finish invoking zhipu chat ai"); } catch (Exception e) { diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java index a8b1ae016..a02668775 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/listener/ZhipuChatAIEventSourceListener.java @@ -1,22 +1,24 @@ package ai.chat2db.server.web.api.controller.ai.zhipu.listener; +import ai.chat2db.server.tools.common.model.LoginUser; import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; -import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletions; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.unfbx.chatgpt.entity.chat.Message; -import lombok.SneakyThrows; +import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatRole; +import ai.chat2db.server.web.api.controller.ai.openai.listener.OpenAIEventSourceListener; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; +import ai.chat2db.server.web.api.controller.ai.utils.PromptService; +import ai.chat2db.server.web.api.controller.ai.zhipu.client.ZhipuChatAIClient; +import ai.chat2db.server.web.api.controller.ai.zhipu.model.ZhipuChatCompletionsOptions; import lombok.extern.slf4j.Slf4j; -import okhttp3.Response; -import okhttp3.ResponseBody; -import okhttp3.sse.EventSource; -import okhttp3.sse.EventSourceListener; -import org.apache.commons.lang3.StringUtils; -import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; -import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +import com.unfbx.chatgpt.entity.chat.tool.Tools; +import com.unfbx.chatgpt.entity.chat.tool.ToolsFunction; + /** * 描述:OpenAIEventSourceListener * @@ -24,111 +26,32 @@ * @date 2023-02-22 */ @Slf4j -public class ZhipuChatAIEventSourceListener extends EventSourceListener { - - private SseEmitter sseEmitter; - - private ObjectMapper mapper = new ObjectMapper().disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); - - public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter) { - this.sseEmitter = sseEmitter; - } - - /** - * {@inheritDoc} - */ - @Override - public void onOpen(EventSource eventSource, Response response) { - log.info("Zhipu Chat Sse connecting..."); - } - - /** - * {@inheritDoc} - */ - @SneakyThrows - @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { - log.info("Zhipu Chat AI response data:{}", data); - if (data.equals("[DONE]")) { - log.info("Zhipu Chat AI closed"); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]") - .reconnectTime(3000)); - sseEmitter.complete(); - return; - } - - ZhipuChatCompletions chatCompletions = mapper.readValue(data, ZhipuChatCompletions.class); - String text = chatCompletions.getData(); - if (Objects.isNull(text)) { - for (FastChatMessage message : chatCompletions.getBody().getChoices()) { - if (message != null && message.getContent() != null) { - text = message.getContent(); - } - } - } - - Message message = new Message(); - message.setContent(text); - sseEmitter.send(SseEmitter.event() - .id(null) - .data(message) - .reconnectTime(3000)); +public class ZhipuChatAIEventSourceListener extends OpenAIEventSourceListener { + + public ZhipuChatAIEventSourceListener(SseEmitter sseEmitter, PromptService promptService, + ChatQueryRequest queryRequest, LoginUser loginUser) { + super(sseEmitter, promptService, queryRequest, loginUser); } @Override - public void onClosed(EventSource eventSource) { - try { - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - } catch (IOException e) { - throw new RuntimeException(e); - } - sseEmitter.complete(); - log.info("ZhipuChatAI close sse connection..."); + public String getName(){ + return "Zhipu"; } @Override - public void onFailure(EventSource eventSource, Throwable t, Response response) { - try { - if (Objects.isNull(response)) { - String message = t.getMessage(); - Message sseMessage = new Message(); - sseMessage.setContent(message); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(sseMessage)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - return; - } - ResponseBody body = response.body(); - String bodyString = Objects.nonNull(t) ? t.getMessage() : ""; - if (Objects.nonNull(body)) { - bodyString = body.string(); - if (StringUtils.isBlank(bodyString) && Objects.nonNull(t)) { - bodyString = t.getMessage(); - } - log.error("Zhipu Chat AI sse response:{}", bodyString); - } else { - log.error("Zhipu Chat AI sse response:{},error:{}", response, t); - } - eventSource.cancel(); - Message message = new Message(); - message.setContent("Zhipu Chat AI error:" + bodyString); - sseEmitter.send(SseEmitter.event() - .id("[ERROR]") - .data(message)); - sseEmitter.send(SseEmitter.event() - .id("[DONE]") - .data("[DONE]")); - sseEmitter.complete(); - } catch (Exception exception) { - log.error("Zhipu Chat AI send data error:", exception); - } + public void functionCall(String prompt){ + FastChatMessage currentMessage = new FastChatMessage(FastChatRole.USER).setContent(prompt); + List messages = new ArrayList<>(); + messages.add(currentMessage); + String requestId = String.valueOf(System.currentTimeMillis()); + ToolsFunction function = PromptService.getToolsFunction(); + ZhipuChatCompletionsOptions completionsOptions = ZhipuChatCompletionsOptions.builder() + .requestId(requestId) + .stream(true) + .toolChoice("auto") + .tools(List.of(new Tools(Tools.Type.FUNCTION.getName(), function))) + .messages(messages) + .build(); + ZhipuChatAIClient.getInstance().streamCompletions(completionsOptions, this); } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java index 4b6359cc2..06c16bd07 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/zhipu/model/ZhipuChatCompletionsOptions.java @@ -5,6 +5,9 @@ import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage; import com.fasterxml.jackson.annotation.JsonProperty; +import com.unfbx.chatgpt.entity.chat.tool.Tools; + +import lombok.Builder; import lombok.Data; import java.util.List; @@ -14,18 +17,16 @@ * generate text that continues from or "completes" provided prompt data. */ @Data +@Builder public final class ZhipuChatCompletionsOptions { @JsonProperty(value = "request_id") private String requestId; // sse-params - @JsonProperty(value = "incremental") + @JsonProperty(value = "stream") private Boolean stream = true; - @JsonProperty(value = "sseFormat") - private String sseFormat = "data"; - /* * The collection of context messages associated with this chat completions request. @@ -33,8 +34,8 @@ public final class ZhipuChatCompletionsOptions { * the behavior of the assistant, followed by alternating messages between the User and * Assistant roles. */ - @JsonProperty(value = "prompt") - private List prompt; + @JsonProperty(value = "messages") + private List messages; // @@ -45,4 +46,14 @@ public final class ZhipuChatCompletionsOptions { */ @JsonProperty(value = "model") private String model; + + + + // 新添加的参数 + @JsonProperty(value = "tool_choice") + private String toolChoice; // 工具选择策略 + + @JsonProperty(value = "tools") + private List tools; // 工具列表 + } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java index 131a6bf6c..f2ce64dfd 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/TableController.java @@ -18,17 +18,18 @@ import ai.chat2db.server.web.api.controller.rdb.vo.SqlVO; import ai.chat2db.server.web.api.controller.rdb.vo.TableVO; import ai.chat2db.spi.model.*; -import ai.chat2db.spi.sql.Chat2DBContext; -import ai.chat2db.spi.sql.ConnectInfo; import com.google.common.collect.Lists; import jakarta.validation.Valid; import lombok.extern.slf4j.Slf4j; + +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.stream.Collectors; @Slf4j @ConnectionInfoAspect @@ -250,4 +251,40 @@ public ActionResult delete(@Valid @RequestBody TableDeleteRequest request) { DropParam dropParam = rdbWebConverter.tableDelete2dropParam(request); return tableService.drop(dropParam); } + + + /** + * 查询ER图 + * + * @param request + * @return + */ + @GetMapping("/er-diagram") + public DataResult erDiagram(@Valid TableBriefQueryRequest request) { + TablePageQueryParam queryParam = rdbWebConverter.tablePageRequest2param(request); + TableSelector tableSelector = new TableSelector(); + tableSelector.setColumnList(true); + tableSelector.setIndexList(false); + PageResult
tableDTOPageResult = tableService.pageQuery(queryParam, tableSelector); + List entityList = tableDTOPageResult.getData().stream().map(table -> { + ErDiagram.Node entity = new ErDiagram.Node(table.getName(), + StringUtils.defaultIfBlank(table.getComment(), table.getName())); + return entity; + }).collect(Collectors.toList()); + List relationList = tableDTOPageResult.getData().stream().flatMap(table -> { + return table.getColumnList().stream().filter(column -> { + String columnName = column.getName(); + Boolean primaryKey = column.getPrimaryKey(); + return columnName != null && columnName.matches(".+_id") && Boolean.FALSE.equals(primaryKey); + }).map(column -> { + String columnName = column.getName(); + String tableName = column.getTableName(); + // 从列名中移除"_id"以获取可能的关联表名 + String potentialForeignKeyTable = columnName.substring(0, columnName.length() - 3); + ErDiagram.Edge relation = new ErDiagram.Edge(columnName,tableName, potentialForeignKeyTable,column.getComment()); + return relation; + }); + }).collect(Collectors.toList()); + return DataResult.of(new ErDiagram(entityList, relationList)); + } } diff --git a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java index b04663dc9..5a37c352a 100644 --- a/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java +++ b/chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/rdb/converter/RdbWebConverter.java @@ -3,6 +3,7 @@ import java.util.List; import ai.chat2db.server.domain.api.param.*; +import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest; import ai.chat2db.server.web.api.controller.data.source.vo.DatabaseVO; import ai.chat2db.server.web.api.controller.rdb.request.*; import ai.chat2db.server.web.api.controller.rdb.vo.ColumnVO; @@ -99,6 +100,14 @@ public abstract class RdbWebConverter { * @return */ public abstract SqlVO dto2vo(Sql dto); + + /** + * 参数转换 + * + * @param request + * @return + */ + public abstract TablePageQueryParam tablePageRequest2param(ChatQueryRequest request); /** * 参数转换 * diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java new file mode 100644 index 000000000..67a2f9920 --- /dev/null +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/model/ErDiagram.java @@ -0,0 +1,42 @@ +package ai.chat2db.spi.model; + +import java.util.List; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; + +/** + * er图 + */ +@Data +@SuperBuilder +@NoArgsConstructor +@AllArgsConstructor +public class ErDiagram { + + private List nodes; + private List edges; + + @Data + @SuperBuilder + @NoArgsConstructor + @AllArgsConstructor + public static class Node { + private String id; + private String label; + } + + @Data + @SuperBuilder + @NoArgsConstructor + @AllArgsConstructor + public static class Edge { + private String id; + private String source; + private String target; + private String label; + } + +} diff --git a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java index 9e6fce81a..88183d9ad 100644 --- a/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java +++ b/chat2db-server/chat2db-spi/src/main/java/ai/chat2db/spi/sql/Chat2DBContext.java @@ -142,6 +142,7 @@ public static void removeContext() { try { if (connection != null && !connection.isClosed()) { connection.close(); + connectInfo.setConnection(null); } } catch (SQLException e) { log.error("close connection error", e); diff --git a/chat2db-server/pom.xml b/chat2db-server/pom.xml index 16c693477..b5b0cf43a 100644 --- a/chat2db-server/pom.xml +++ b/chat2db-server/pom.xml @@ -222,7 +222,7 @@ com.unfbx chatgpt-java - 1.0.8 + 1.1.5 org.slf4j