Text-to-SQL
This article describes a Text-to-SQL implementation built with Spring AI. Given a user query, the application returns the actual query results by giving the LLM a tool that executes SQL statements.
The complete source code is available on GitHub at JavaAIDev/simple-text-to-sql.
Text-to-SQL is a typical use of AI to improve productivity. With Text-to-SQL, non-technical users describe their database query requirements in natural language. These queries are sent to an LLM, which generates SQL statements to answer them. The LLM can also use tools to execute the SQL statements and return the query results to the user.
The easiest way to implement Text-to-SQL is to leverage the capabilities of the LLM directly. Modern LLMs are very good at generating code, including SQL statements. To generate SQL statements that answer user queries, the LLM needs information about the database, including metadata about the tables and their columns.
Prerequisites
Before running the Text-to-SQL application, you should have:
- Java 21
- A running Postgres server with the sample table data loaded. If you use the Docker Compose file in the repo, the sample data is loaded automatically.
- An OpenAI API key set as the environment variable OPENAI_API_KEY.
Database Metadata
JDBC provides an API for reading metadata from a database. We can use this API to extract information about tables and their columns.
In the implementation, I use Java record types to define the different types of metadata.
- DatabaseMetadata
- TableInfo
- ColumnInfo
package com.javaaidev.text2sql.metadata;

import java.util.List;

public record DatabaseMetadata(List<TableInfo> tables) {
}

package com.javaaidev.text2sql.metadata;

import java.util.List;

public record TableInfo(String name, String description, String catalog,
    String schema, List<ColumnInfo> columns) {
}

package com.javaaidev.text2sql.metadata;

public record ColumnInfo(String name, String dataType, String description) {
}
DatabaseMetadataHelper, shown below, uses the JDBC API to extract database metadata as a DatabaseMetadata object. The extracted DatabaseMetadata object is then serialized as a JSON string using Jackson. This JSON string is the metadata sent to the LLM.
package com.javaaidev.text2sql.metadata;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Objects;
import javax.sql.DataSource;

public class DatabaseMetadataHelper {

  private final DataSource dataSource;
  private final ObjectMapper objectMapper;

  public DatabaseMetadataHelper(DataSource dataSource,
      ObjectMapper objectMapper) {
    this.dataSource = dataSource;
    this.objectMapper = objectMapper;
  }

  public String extractMetadataJson() throws SQLException {
    var metadata = extractMetadata();
    try {
      return objectMapper.writeValueAsString(metadata);
    } catch (JsonProcessingException e) {
      // Fall back to the record's toString() if JSON serialization fails
      return Objects.toString(metadata);
    }
  }

  public DatabaseMetadata extractMetadata() throws SQLException {
    var tablesInfo = new ArrayList<TableInfo>();
    // Close the connection when done so it is returned to the pool
    try (var connection = dataSource.getConnection()) {
      var metadata = connection.getMetaData();
      try (var tables = metadata.getTables(null, null, null,
          new String[]{"TABLE"})) {
        while (tables.next()) {
          var tableName = tables.getString("TABLE_NAME");
          var tableDescription = tables.getString("REMARKS");
          var tableCatalog = tables.getString("TABLE_CAT");
          var tableSchema = tables.getString("TABLE_SCHEM");
          var columnsInfo = new ArrayList<ColumnInfo>();
          // Scope the lookup to this table's catalog and schema so that
          // same-named tables in other schemas don't contribute columns
          try (var columns = metadata.getColumns(tableCatalog, tableSchema,
              tableName, null)) {
            while (columns.next()) {
              var columnName = columns.getString("COLUMN_NAME");
              var datatype = columns.getString("TYPE_NAME");
              var columnDescription = columns.getString("REMARKS");
              columnsInfo.add(
                  new ColumnInfo(columnName, datatype, columnDescription)
              );
            }
          }
          tablesInfo.add(
              new TableInfo(
                  tableName,
                  tableDescription,
                  tableCatalog,
                  tableSchema,
                  columnsInfo
              )
          );
        }
      }
    }
    return new DatabaseMetadata(tablesInfo);
  }
}
Use Advisor
The extracted database metadata JSON string is set as the system text of the prompt sent to the LLM. A Spring AI advisor is used to set the system text.
DatabaseMetadataAdvisor is an implementation of CallAroundAdvisor. Before the request is sent to the LLM, it is updated to include the system text prompt template and its variables. The prompt template of the system text has a variable, table_schemas. At runtime, this variable is replaced with the database metadata JSON string extracted using DatabaseMetadataHelper.
package com.javaaidev.text2sql;

import com.javaaidev.text2sql.metadata.DatabaseMetadataHelper;
import java.sql.SQLException;
import java.util.HashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.core.Ordered;

/**
 * Advisor to update system text of the prompt
 */
public class DatabaseMetadataAdvisor implements CallAroundAdvisor {

  private static final String DEFAULT_SYSTEM_TEXT = """
      You are a Postgres expert. Please help to generate a Postgres query, then run the query to answer the question. The output should be in tabular format.
      ===Tables
      {table_schemas}
      """;

  private final DatabaseMetadataHelper databaseMetadataHelper;
  private final String tableSchemas;

  private static final Logger LOGGER = LoggerFactory.getLogger(
      DatabaseMetadataAdvisor.class);

  public DatabaseMetadataAdvisor(
      DatabaseMetadataHelper databaseMetadataHelper) {
    this.databaseMetadataHelper = databaseMetadataHelper;
    this.tableSchemas = getDatabaseMetadata();
    LOGGER.info("Loaded database metadata: {}", this.tableSchemas);
  }

  private String getDatabaseMetadata() {
    try {
      return databaseMetadataHelper.extractMetadataJson();
    } catch (SQLException e) {
      LOGGER.error("Failed to load database metadata", e);
      throw new RuntimeException(e);
    }
  }

  @Override
  public AdvisedResponse aroundCall(AdvisedRequest advisedRequest,
      CallAroundAdvisorChain chain) {
    var systemParams = new HashMap<>(advisedRequest.systemParams());
    systemParams.put("table_schemas", tableSchemas);
    var request = AdvisedRequest.from(advisedRequest)
        .systemText(DEFAULT_SYSTEM_TEXT)
        .systemParams(systemParams)
        .build();
    return chain.nextAroundCall(request);
  }

  @Override
  public String getName() {
    return getClass().getSimpleName();
  }

  @Override
  public int getOrder() {
    return Ordered.HIGHEST_PRECEDENCE;
  }
}
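To picture what happens to the table_schemas variable, here is a minimal, dependency-free sketch of placeholder substitution. It is only an illustration: Spring AI renders the template with its own template engine, and the class name PromptTemplateSketch, the render method, and the sample metadata JSON are all made up for this example.

```java
import java.util.Map;

// Hypothetical illustration only; not how Spring AI renders templates internally.
final class PromptTemplateSketch {

  // Replaces every {name} placeholder with the matching value from params.
  static String render(String template, Map<String, String> params) {
    var result = template;
    for (var entry : params.entrySet()) {
      result = result.replace("{" + entry.getKey() + "}", entry.getValue());
    }
    return result;
  }

  public static void main(String[] args) {
    var template = """
        You are a Postgres expert.
        ===Tables
        {table_schemas}
        """;
    // Made-up metadata JSON standing in for the output of DatabaseMetadataHelper
    var metadataJson = "{\"tables\":[{\"name\":\"movies\"}]}";
    System.out.println(render(template, Map.of("table_schemas", metadataJson)));
  }
}
```

The rendered system text ends with the metadata JSON in place of the placeholder, which is what the LLM sees alongside the user's question.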
Execute SQL Statements
So far, we have only asked the LLM to output SQL statements. However, end users may not care about the SQL itself; they want to see the actual query results. If the LLM only generates SQL statements, those statements still need to be executed manually in a database client to get the results. It’s better if the LLM outputs the query results instead of the SQL statements.
To get the query results, the generated SQL statements need to be executed against the real database. An LLM cannot execute SQL statements by itself; it requires an external tool that executes the statements and returns the results.
Spring AI provides support for calling external tools. All we need to do is declare Java Functions.
RunSqlQueryTool implements the Function<RunSqlQueryRequest, RunSqlQueryResponse> interface. The tool’s input type is RunSqlQueryRequest, which defines a single field, query, representing the SQL statement to execute. The tool’s output type is RunSqlQueryResponse, which defines two fields, result and error, representing the successful result and the error message, respectively.
In the implementation of RunSqlQueryTool, I use JdbcClient from Spring JDBC to execute SQL statements and get the results. The execution result type is List<Map<String, Object>>. Each Map in the List represents a row in the ResultSet, and a column's value is looked up in the Map by its column name. The execution result is returned in CSV format, with the column names as the first row. If execution fails, the exception message is used as the error result.
- RunSqlQueryTool
- RunSqlQueryRequest
- RunSqlQueryResponse
package com.javaaidev.text2sql.tool;

import java.io.IOException;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.util.CollectionUtils;

/**
 * Use {@linkplain JdbcClient} to run SQL query and output result in CSV format
 */
public class RunSqlQueryTool implements
    Function<RunSqlQueryRequest, RunSqlQueryResponse> {

  private final JdbcClient jdbcClient;

  private static final Logger LOGGER = LoggerFactory.getLogger(
      RunSqlQueryTool.class);

  public RunSqlQueryTool(JdbcClient jdbcClient) {
    this.jdbcClient = jdbcClient;
  }

  @Override
  public RunSqlQueryResponse apply(RunSqlQueryRequest request) {
    try {
      LOGGER.info("SQL query to run [{}]", request.query());
      return new RunSqlQueryResponse(runQuery(request.query()), null);
    } catch (Exception e) {
      // Return the error message so the LLM can see why the query failed
      return new RunSqlQueryResponse(null, e.getMessage());
    }
  }

  private String runQuery(String query) {
    var rows = jdbcClient.sql(query)
        .query().listOfRows();
    if (CollectionUtils.isEmpty(rows)) {
      return "";
    }
    // Sort column names so the column order is stable
    var fields = rows.getFirst().keySet().stream().sorted().toList();
    var format = CSVFormat.DEFAULT.builder()
        .setHeader(fields.toArray(new String[0]))
        .setSkipHeaderRecord(false)
        .setRecordSeparator('\n')
        .build();
    var builder = new StringBuilder();
    // Creating the CSVPrinter writes the header row;
    // CSVFormat.printRecord on its own does not
    try (var printer = new CSVPrinter(builder, format)) {
      for (Map<String, Object> row : rows) {
        printer.printRecord(fields.stream().map(row::get).toArray());
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    return builder.toString();
  }
}
package com.javaaidev.text2sql.tool;

/**
 * Request to run SQL query
 *
 * @param query SQL query
 */
public record RunSqlQueryRequest(String query) {
}

package com.javaaidev.text2sql.tool;

/**
 * Response of SQL query
 *
 * @param result Success result
 * @param error  Error result
 */
public record RunSqlQueryResponse(String result, String error) {
}
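To show the shape of the CSV text that goes back to the LLM, here is a dependency-free sketch of the same idea: sorted column names as the header row, then one line per row. It is not the tool's implementation (which uses Apache Commons CSV); the class name CsvShapeSketch, its simplified quoting rules, and the sample rows are assumptions made for illustration.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical illustration; the real tool delegates CSV writing to Apache Commons CSV.
final class CsvShapeSketch {

  // Quotes a value when it contains a comma, a double quote, or a line break.
  static String escape(Object value) {
    if (value == null) {
      return "";
    }
    var text = value.toString();
    if (text.contains(",") || text.contains("\"")
        || text.contains("\n") || text.contains("\r")) {
      return "\"" + text.replace("\"", "\"\"") + "\"";
    }
    return text;
  }

  // Joins the escaped values of one record with commas and ends the line.
  static String line(List<?> values) {
    return values.stream().map(CsvShapeSketch::escape)
        .collect(Collectors.joining(",")) + "\n";
  }

  // Renders rows as CSV: header line with sorted column names, then the rows.
  static String toCsv(List<Map<String, Object>> rows) {
    if (rows.isEmpty()) {
      return "";
    }
    var fields = rows.get(0).keySet().stream().sorted().toList();
    var csv = new StringBuilder(line(fields));
    for (var row : rows) {
      csv.append(line(fields.stream().map(row::get).toList()));
    }
    return csv.toString();
  }

  public static void main(String[] args) {
    // Made-up rows standing in for the result of JdbcClient's listOfRows()
    Map<String, Object> first = new LinkedHashMap<>();
    first.put("title", "Movie A");
    first.put("country", "Country X");
    Map<String, Object> second = new LinkedHashMap<>();
    second.put("title", "Movie B, Part 2");
    second.put("country", "Country Y");
    System.out.print(toCsv(List.of(first, second)));
    // Prints:
    // country,title
    // Country X,Movie A
    // Country Y,"Movie B, Part 2"
  }
}
```

Sorting the column names keeps the header and the values aligned regardless of the Map's iteration order.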
To register RunSqlQueryTool as a function, we define a bean of type RunSqlQueryTool. The bean name is the tool’s name. Use the @Description annotation to add the tool’s description. From the generic type definition of RunSqlQueryTool, Spring AI can detect the input and output types. From the Java type definitions, Spring AI builds JSON schemas to describe the types. This information is passed to the LLM in the API request.
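import com.javaaidev.text2sql.tool.RunSqlQueryTool;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;
import org.springframework.jdbc.core.simple.JdbcClient;

@Configuration
public class AppConfiguration {

  // The bean name "runSqlQuery" becomes the tool's name
  @Bean
  @Description("Query database using SQL")
  public RunSqlQueryTool runSqlQuery(JdbcClient jdbcClient) {
    return new RunSqlQueryTool(jdbcClient);
  }
}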
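To give a feel for how a JSON schema can be derived from a Java type, the following sketch reads a record's components via reflection and emits a very small schema. It is not Spring AI's schema generator; the class name RecordSchemaSketch, the nested stand-in record, and the simplified type mapping are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.stream.Collectors;

// Hypothetical illustration; Spring AI has its own, more complete schema generation.
final class RecordSchemaSketch {

  // Stand-in for the tool's input type
  record RunSqlQueryRequest(String query) {
  }

  // Maps a few Java types to JSON Schema type names; anything else becomes "object".
  static String jsonType(Class<?> type) {
    if (type == String.class) {
      return "string";
    }
    if (type == boolean.class || type == Boolean.class) {
      return "boolean";
    }
    if (Number.class.isAssignableFrom(type) || type == int.class
        || type == long.class || type == double.class) {
      return "number";
    }
    return "object";
  }

  // Builds a minimal schema with one property per record component.
  static String schemaOf(Class<? extends Record> recordType) {
    var properties = Arrays.stream(recordType.getRecordComponents())
        .map(c -> "\"%s\":{\"type\":\"%s\"}"
            .formatted(c.getName(), jsonType(c.getType())))
        .collect(Collectors.joining(","));
    return "{\"type\":\"object\",\"properties\":{" + properties + "}}";
  }

  public static void main(String[] args) {
    System.out.println(schemaOf(RunSqlQueryRequest.class));
    // Prints: {"type":"object","properties":{"query":{"type":"string"}}}
  }
}
```

A schema like this tells the LLM that the tool expects an object with a single string property named query.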
REST Controller
Now we can build the REST API to implement Text-to-SQL. ChatClient.Builder is used to build ChatClients. DatabaseMetadataAdvisor is added as a default advisor for the ChatClient.
Here I use OpenAI GPT-4o mini as the model. runSqlQuery is included as a function name in the request, so the LLM can use this tool to execute the generated SQL statements.
package com.javaaidev.text2sql.controller;

import com.javaaidev.text2sql.DatabaseMetadataAdvisor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class ChatController {

  private final ChatClient chatClient;

  public ChatController(ChatClient.Builder builder,
      DatabaseMetadataAdvisor databaseMetadataAdvisor) {
    chatClient = builder.defaultAdvisors(databaseMetadataAdvisor).build();
  }

  @PostMapping("/chat")
  public ChatResponse chat(@RequestBody ChatRequest request) {
    return new ChatResponse(
        chatClient.prompt().user(request.input())
            .options(OpenAiChatOptions.builder()
                .model(ChatModel.GPT_4_O_MINI)
                .temperature(0.0)
                .function("runSqlQuery")
                .build())
            .call().content());
  }
}
Test
Now we can test the REST API. Start the server and use Swagger UI to run a query.
Sample query:
how many movies are produced in the United States?
Output:
There are 2,058 movies produced in the United States.
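If you would rather call the endpoint from code than from Swagger UI, a sketch along these lines may help. It assumes the application listens on http://localhost:8080 and that ChatRequest is sent as a JSON object with a single input field (inferred from request.input() in the controller); the class name ChatApiClientSketch and its helper methods are made up for this example, and the response body is printed as-is because its exact JSON shape isn't shown in this article.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Hypothetical client sketch; the base URL and the request JSON shape are assumptions.
final class ChatApiClientSketch {

  // Builds the JSON body, escaping backslashes and double quotes in the question.
  // Control characters are not handled here; use a JSON library for real input.
  static String toJsonBody(String question) {
    var escaped = question.replace("\\", "\\\\").replace("\"", "\\\"");
    return "{\"input\":\"" + escaped + "\"}";
  }

  // Builds a POST request to the /chat endpoint of the given base URL.
  static HttpRequest buildRequest(String baseUrl, String question) {
    return HttpRequest.newBuilder(URI.create(baseUrl + "/chat"))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(toJsonBody(question)))
        .build();
  }

  public static void main(String[] args) throws Exception {
    var request = buildRequest("http://localhost:8080",
        "how many movies are produced in the United States?");
    // Requires the Text-to-SQL application to be running locally
    var response = HttpClient.newHttpClient()
        .send(request, HttpResponse.BodyHandlers.ofString());
    System.out.println(response.statusCode() + " " + response.body());
  }
}
```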
The Text-to-SQL implementation described here doesn't use RAG to find the table metadata relevant to the input query. Check out my Text-to-SQL course.
If you prefer a text version, check out my Text-to-SQL book.