Skip to content

Commit

Permalink
Merge branch 'chat2db:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
tmlx1990 authored Jul 10, 2024
2 parents 8625bc2 + 9b1bfad commit 640ab10
Show file tree
Hide file tree
Showing 27 changed files with 2,091 additions and 201 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import ai.chat2db.spi.SqlBuilder;
import ai.chat2db.spi.jdbc.DefaultMetaService;
import ai.chat2db.spi.model.*;
import ai.chat2db.spi.sql.Chat2DBContext;
import ai.chat2db.spi.sql.SQLExecutor;
import ai.chat2db.spi.util.SortUtils;
import com.google.common.collect.Lists;
import jakarta.validation.constraints.NotEmpty;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;

import java.sql.Connection;
Expand All @@ -20,62 +23,90 @@
import java.util.*;
import java.util.stream.Collectors;

@Slf4j
public class DMMetaData extends DefaultMetaService implements MetaData {

private List<String> systemSchemas = Arrays.asList("CTISYS", "SYS","SYSDBA","SYSSSO","SYSAUDITOR");
private List<String> systemSchemas = Arrays.asList("CTISYS", "SYS", "SYSDBA", "SYSSSO", "SYSAUDITOR");

@Override
public List<Schema> schemas(Connection connection, String databaseName) {
List<Schema> schemas = SQLExecutor.getInstance().schemas(connection, databaseName, null);
return SortUtils.sortSchema(schemas, systemSchemas);
}
private String format(String tableName){

private String format(String tableName) {
return "\"" + tableName + "\"";
}
}

private static String tableDDL = "SELECT dbms_metadata.get_ddl('TABLE', '%s','%s') as ddl FROM dual ;";
private static String tableComment = "select COMMENTS from dba_tab_comments where OWNER='%s' and TABLE_TYPE='TABLE' and TABLE_NAME='%s';";
private static String columnComment = "SELECT COLNAME,COMMENT$ FROM SYS.SYSCOLUMNCOMMENTS where SCHNAME = '%s' and TVNAME = '%s' and TABLE_TYPE = 'TABLE';";

public String tableDDL(Connection connection, String databaseName, String schemaName, String tableName) {
public String tableDDL(Connection connection, String databaseName, String schemaName, String tableName) {
String tableDDLSql = String.format(tableDDL, tableName, schemaName);
String tableCommentSql = String.format(tableComment, schemaName, tableName);
String columnCommentSql = String.format(columnComment, schemaName, tableName);
StringBuilder ddlBuilder = new StringBuilder();
SQLExecutor.getInstance().execute(connection, tableDDLSql, resultSet -> {
if (resultSet.next()) {
String ddl = resultSet.getString("ddl");
ddlBuilder.append(ddl).append("\n");
}
});
SQLExecutor.getInstance().execute(connection, tableCommentSql, resultSet -> {
if (resultSet.next()) {
String comments = resultSet.getString("COMMENTS");
if (Objects.nonNull(comments)) {
ddlBuilder.append("COMMENT ON TABLE ").append(format(schemaName)).append(".").append(format(tableName))
.append(" IS ").append(comments).append(";").append("\n");
MetaData metaData = Chat2DBContext.getMetaData();
List<Table> tables = metaData.tables(connection, databaseName, schemaName, tableName);
if (CollectionUtils.isNotEmpty(tables)) {
String tableComment = tables.get(0).getComment();
if (StringUtils.isNotBlank(tableComment)) {
ddlBuilder.append("COMMENT ON TABLE ").append(format(schemaName)).append(".").append(format(tableName))
.append(" IS '").append(tableComment.replace("'", "''")).append("'").append(";").append("\n");
}
}
List<TableColumn> columns = metaData.columns(connection, databaseName, schemaName, tableName);
if (CollectionUtils.isNotEmpty(columns)) {
for (TableColumn column : columns) {
String columnName = column.getName();
String comment = column.getComment();
if (StringUtils.isNotBlank(comment)) {
ddlBuilder.append("COMMENT ON COLUMN ").append(format(schemaName)).append(".").append(format(tableName))
.append(".").append(format(columnName)).append(" IS ")
.append("'").append(comment.replace("'", "''"))
.append("';").append("\n");
}
}
});
SQLExecutor.getInstance().execute(connection, columnCommentSql, resultSet -> {
while (resultSet.next()) {
String columnName = resultSet.getString("COLNAME");
String comment = resultSet.getString("COMMENT$");
ddlBuilder.append("COMMENT ON COLUMN ").append(format(schemaName)).append(".").append(format(tableName))
.append(".").append(format(columnName)).append(" IS ").append("'").append(comment).append("';").append("\n");
}
if (tableName.startsWith("V$")){
return ddlBuilder.toString();
}
List<TableIndex> indexes = metaData.indexes(connection, databaseName, schemaName, tableName);
if (CollectionUtils.isNotEmpty(indexes)) {
for (TableIndex index : indexes) {
String indexName = index.getName();
if (StringUtils.isNotBlank(indexName)) {
String sql = "select DBMS_METADATA.GET_DDL('INDEX','%s') as INDEX_DDL";
try {
SQLExecutor.getInstance().execute(connection, String.format(sql,indexName), resultSet -> {
if (resultSet.next()) {
ddlBuilder.append(resultSet.getString("INDEX_DDL")).append("\n");
}
});
} catch (Exception e) {
log.warn("Failed to get the DDL of the index.");
for (TableIndex tableIndex : indexes) {
DMIndexTypeEnum indexTypeEnum = DMIndexTypeEnum.getByType(tableIndex.getType());
ddlBuilder.append("\n").append(indexTypeEnum.buildIndexScript(tableIndex)).append(";");
}
}
}
}
});
}
return ddlBuilder.toString();
}

private static String ROUTINES_SQL
= "SELECT OWNER, NAME, TEXT FROM ALL_SOURCE WHERE TYPE = '%s' AND OWNER = '%s' AND NAME = '%s' ORDER BY LINE";
= "SELECT OWNER, NAME, TEXT FROM ALL_SOURCE WHERE TYPE = '%s' AND OWNER = '%s' AND NAME = '%s' ORDER BY LINE";

@Override
public Function function(Connection connection, @NotEmpty String databaseName, String schemaName,
String functionName) {
String functionName) {

String sql = String.format(ROUTINES_SQL, "PROC",schemaName, functionName);
String sql = String.format(ROUTINES_SQL, "PROC", schemaName, functionName);
return SQLExecutor.getInstance().execute(connection, sql, resultSet -> {
StringBuilder sb = new StringBuilder();
while (resultSet.next()) {
Expand All @@ -94,8 +125,8 @@ public Function function(Connection connection, @NotEmpty String databaseName, S

@Override
public Procedure procedure(Connection connection, @NotEmpty String databaseName, String schemaName,
String procedureName) {
String sql = String.format(ROUTINES_SQL, "PROC", schemaName,procedureName);
String procedureName) {
String sql = String.format(ROUTINES_SQL, "PROC", schemaName, procedureName);
return SQLExecutor.getInstance().execute(connection, sql, resultSet -> {
StringBuilder sb = new StringBuilder();
while (resultSet.next()) {
Expand All @@ -111,8 +142,8 @@ public Procedure procedure(Connection connection, @NotEmpty String databaseName,
}

private static String TRIGGER_SQL
= "SELECT OWNER, TRIGGER_NAME, TABLE_OWNER, TABLE_NAME, TRIGGERING_TYPE, TRIGGERING_EVENT, STATUS, TRIGGER_BODY "
+ "FROM ALL_TRIGGERS WHERE OWNER = '%s' AND TRIGGER_NAME = '%s'";
= "SELECT OWNER, TRIGGER_NAME, TABLE_OWNER, TABLE_NAME, TRIGGERING_TYPE, TRIGGERING_EVENT, STATUS, TRIGGER_BODY "
+ "FROM ALL_TRIGGERS WHERE OWNER = '%s' AND TRIGGER_NAME = '%s'";

private static String TRIGGER_SQL_LIST = "SELECT OWNER, TRIGGER_NAME FROM ALL_TRIGGERS WHERE OWNER = '%s'";

Expand All @@ -134,7 +165,7 @@ public List<Trigger> triggers(Connection connection, String databaseName, String

@Override
public Trigger trigger(Connection connection, @NotEmpty String databaseName, String schemaName,
String triggerName) {
String triggerName) {

String sql = String.format(TRIGGER_SQL, schemaName, triggerName);
return SQLExecutor.getInstance().execute(connection, sql, resultSet -> {
Expand All @@ -150,7 +181,7 @@ public Trigger trigger(Connection connection, @NotEmpty String databaseName, Str
}

private static String VIEW_SQL
= "SELECT OWNER, VIEW_NAME, TEXT FROM ALL_VIEWS WHERE OWNER = '%s' AND VIEW_NAME = '%s'";
= "SELECT OWNER, VIEW_NAME, TEXT FROM ALL_VIEWS WHERE OWNER = '%s' AND VIEW_NAME = '%s'";

@Override
public Table view(Connection connection, String databaseName, String schemaName, String viewName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ public void updateProcedure(Connection connection, String databaseName, String s
} catch (Exception e) {
connection.rollback();
throw new RuntimeException(e);
}finally {
connection.setAutoCommit(true);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,11 @@ public List<TableColumn> columns(Connection connection, String databaseName, Str
log.error("getDefaultValue error",e);
}
tableColumn.setName(resultSet.getString("COLUMN_NAME"));
tableColumn.setColumnType(resultSet.getString("DATA_TYPE"));
String dataType = resultSet.getString("DATA_TYPE");
if(dataType.contains("(")){
dataType = dataType.substring(0,dataType.indexOf("(")).trim();
}
tableColumn.setColumnType(dataType);
Integer dataPrecision = resultSet.getInt("DATA_PRECISION");
if(resultSet.getString("DATA_PRECISION") != null) {
tableColumn.setColumnSize(dataPrecision);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public String tableDDL(Connection connection, String databaseName, String schema
});
}

private static String SELECT_TABLES_SQL = "SELECT t.name AS TableName, mm.value as comment FROM sys.tables t LEFT JOIN(SELECT * from sys.extended_properties ep where ep.minor_id = 0 AND ep.name = 'MS_Description') mm ON t.object_id = mm.major_id WHERE t.schema_id= SCHEMA_ID('%S')";
private static String SELECT_TABLES_SQL = "SELECT t.name AS TableName, mm.value as comment FROM sys.tables t LEFT JOIN(SELECT * from sys.extended_properties ep where ep.minor_id = 0 AND ep.name = 'MS_Description') mm ON t.object_id = mm.major_id WHERE t.schema_id= SCHEMA_ID('%s')";

@Override
public List<Table> tables(Connection connection, String databaseName, String schemaName, String tableName) {
Expand Down Expand Up @@ -412,4 +412,5 @@ public List<String> getSystemDatabases() {
public List<String> getSystemSchemas() {
return systemSchemas;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,16 @@ public enum SqlServerColumnTypeEnum implements ColumnBuilder {
XML("XML", false, false, true, false, false, false, true, true),


OTHER("OTHER", false, false, true, false, false, false, true, true),
;
private ColumnType columnType;

public static SqlServerColumnTypeEnum getByType(String dataType) {
return COLUMN_TYPE_MAP.get(dataType.toUpperCase());
SqlServerColumnTypeEnum typeEnum = COLUMN_TYPE_MAP.get(dataType.toUpperCase());
if (typeEnum == null) {
return OTHER;
}
return typeEnum;
}

private static Map<String, SqlServerColumnTypeEnum> COLUMN_TYPE_MAP = Maps.newHashMap();
Expand Down Expand Up @@ -255,7 +260,9 @@ private String buildDataType(TableColumn column, SqlServerColumnTypeEnum type) {
}
return script.toString();
}

if(OTHER.equals(columnType)){
return column.getColumnType();
}

return columnType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ public DataResult<ExecuteResult> executeUpdate(DlExecuteParam param) {
List<String> sqlList = SqlUtils.parse(param.getSql(), dbType);
Connection connection = Chat2DBContext.getConnection();
try {
connection.setAutoCommit(false);
// connection.setAutoCommit(false);
for (String originalSql : sqlList) {
ExecuteResult executeResult = executor.executeUpdate(originalSql, connection, 1);
dataResult.setData(executeResult);
addOperationLog(executeResult);
}
connection.commit();
// connection.commit();
} catch (Exception e) {
log.error("executeUpdate error", e);
dataResult.setSuccess(false);
Expand Down
11 changes: 8 additions & 3 deletions chat2db-server/chat2db-spi/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,15 @@
<version>1.19.0</version> <!-- Make sure to use the latest version -->
</dependency>

<!-- https://mvnrepository.com/artifact/org.mongodb/bson -->
<dependency>
<groupId>org.mongodb</groupId>
<artifactId>bson</artifactId>
<groupId>org.antlr</groupId>
<artifactId>antlr4</artifactId>
<version>4.9.1</version>
</dependency>
<dependency>
<groupId>com.oceanbase</groupId>
<artifactId>ob-sql-parser</artifactId>
<version>1.2.1</version>
</dependency>
</dependencies>
<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ public static String getStringValue(String value) {
}
value = value.replace("\\", "\\\\");
value = value.replace("'", "\\'");
value = value.replace("\"", "\\\"");
return "'" + value + "'";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ private void close(Connection connection,Session session,SSHInfo ssh){
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
} catch (Exception e) {
}
}
if (session != null) {
try {
session.delPortForwardingL(Integer.parseInt(ssh.getLocalPort()));
} catch (JSchException e) {
} catch (Exception e) {
}
try {
session.disconnect();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ public String buildSqlByQuery(QueryResult queryResult) {
String tableName = queryResult.getTableName();
StringBuilder stringBuilder = new StringBuilder();
MetaData metaSchema = Chat2DBContext.getMetaData();
String dbType = Chat2DBContext.getDBConfig().getDbType();
List<String> keyColumns = getPrimaryColumns(headerList);
for (int i = 0; i < operations.size(); i++) {
ResultOperation operation = operations.get(i);
Expand All @@ -104,14 +105,19 @@ public String buildSqlByQuery(QueryResult queryResult) {
String sql = "";
if ("UPDATE".equalsIgnoreCase(operation.getType())) {
sql = getUpdateSql(tableName, headerList, row, odlRow, metaSchema, keyColumns, false);
if("MYSQL".equalsIgnoreCase(dbType)){
sql = sql + " LIMIT 1";
}
} else if ("CREATE".equalsIgnoreCase(operation.getType())) {
sql = getInsertSql(tableName, headerList, row, metaSchema);
} else if ("DELETE".equalsIgnoreCase(operation.getType())) {
sql = getDeleteSql(tableName, headerList, odlRow, metaSchema, keyColumns);
if("MYSQL".equalsIgnoreCase(dbType)){
sql = sql + " LIMIT 1";
}
} else if ("UPDATE_COPY".equalsIgnoreCase(operation.getType())) {
sql = getUpdateSql(tableName, headerList, row, row, metaSchema, keyColumns, true);
}

stringBuilder.append(sql + ";\n");
}
return stringBuilder.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@ public class ConnectionPool {

}

public static synchronized void removeConnection(Long datasourceId) {
CONNECTION_MAP.forEach((k, v) -> {
if (k.contains(String.valueOf(datasourceId))) {
try {
Connection connection = v.getConnection();
if (connection != null) {
connection.close();
CONNECTION_MAP.remove(k);
}
} catch (SQLException e) {
log.error("close connection error", e);
}
}
});
}


public static Connection getConnection(ConnectInfo connectInfo) {
Connection connection = connectInfo.getConnection();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package ai.chat2db.spi.sql;

import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.commons.lang3.ClassUtils;

/**
* @author luojun
* @version 1.0
* @description: 接口定义
* @date 2024/5/31 19:05
**/
public class DocumentUtils {

public static LinkedHashMap<String, Object> convertToMap(Object obj) {
LinkedHashMap<String, Object> map = new LinkedHashMap<>();
if (obj == null) {
return map;
}
if (ClassUtils.isPrimitiveOrWrapper(obj.getClass()) || String.class.equals(obj.getClass())) {
map.put("result", obj);
return map;
}
for (Map.Entry<String, Object> entry : ((Map<String, Object>) obj).entrySet()) {
Object value = entry.getValue();
if (value == null) {
map.put(entry.getKey(), null);
} else if (ClassUtils.isPrimitiveOrWrapper(value.getClass()) || String.class.equals(value.getClass())) {
map.put(entry.getKey(), value);
} else if (entry.getValue() instanceof Map) {
LinkedHashMap<String, Object> mmp = convertToMap(entry.getValue());
map.put(entry.getKey(), mmp);
} else {
map.put(entry.getKey(), entry.getValue().toString());
}
}
return map;
}
}
Loading

0 comments on commit 640ab10

Please sign in to comment.