Text-to-SQL

Summary

This article describes a Text-to-SQL implementation using Spring AI. Given a user query, the LLM generates a SQL statement, runs it through a tool provided by the application, and returns the actual query results.

The complete source is available on GitHub at JavaAIDev/simple-text-to-sql.

Text-to-SQL is a typical use of AI to improve productivity. With Text-to-SQL, non-technical people describe their database queries in natural language. These descriptions are sent to an LLM, which generates SQL statements to answer them. The LLM can also use tools to execute the SQL statements and return the query results to the user.

The easiest way to implement Text-to-SQL is to use the capability of the LLM directly. Modern LLMs are very good at generating code, including SQL statements. To generate SQL statements that answer user queries, the LLM needs information about the database, including metadata of the tables and their columns.

Prerequisites

Before running the Text-to-SQL application, you should have:

  • Java 21
  • A running Postgres server with sample table data loaded. If you use the Docker Compose file in the repo, the sample data is loaded automatically.
  • An OpenAI API key set as the environment variable OPENAI_API_KEY.

Database Metadata

JDBC provides an API, java.sql.DatabaseMetaData, for reading metadata from a database, such as its tables and columns. We use this API to describe the database to the LLM.

In the implementation, I use Java record types to define different types of metadata.

package com.javaaidev.text2sql.metadata;

import java.util.List;

public record DatabaseMetadata(List<TableInfo> tables) {

}
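
The TableInfo and ColumnInfo records referenced by DatabaseMetadata are not listed in this article. A minimal sketch is shown here, with component names that are assumptions inferred from the constructor calls in DatabaseMetadataHelper; the definitions in the repository may differ.

```java
import java.util.List;

// Hypothetical component names, inferred from how DatabaseMetadataHelper
// constructs these records; check the repository for the real definitions.
// In a real project each record would be public and live in its own file.
record ColumnInfo(String name, String dataType, String description) {
}

record TableInfo(
    String name,
    String description,
    String catalog,
    String schema,
    List<ColumnInfo> columns) {
}
```

Because records expose their components as properties, Jackson can serialize them to JSON without extra annotations, which is what makes them convenient as the metadata sent to the LLM.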

DatabaseMetadataHelper, shown below, uses the JDBC API to extract database metadata as a DatabaseMetadata object. The extracted DatabaseMetadata object is then serialized as a JSON string using Jackson. This JSON string is the metadata sent to the LLM.

DatabaseMetadataHelper to extract database metadata
package com.javaaidev.text2sql.metadata;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Objects;
import javax.sql.DataSource;

public class DatabaseMetadataHelper {

  private final DataSource dataSource;
  private final ObjectMapper objectMapper;

  public DatabaseMetadataHelper(DataSource dataSource,
      ObjectMapper objectMapper) {
    this.dataSource = dataSource;
    this.objectMapper = objectMapper;
  }

  public String extractMetadataJson() throws SQLException {
    var metadata = extractMetadata();
    try {
      return objectMapper.writeValueAsString(metadata);
    } catch (JsonProcessingException e) {
      // Fall back to the record's toString() if JSON serialization fails
      return Objects.toString(metadata);
    }
  }

  public DatabaseMetadata extractMetadata() throws SQLException {
    var tablesInfo = new ArrayList<TableInfo>();
    // Close the connection when done so it is not leaked
    try (var connection = dataSource.getConnection()) {
      var metadata = connection.getMetaData();
      try (var tables = metadata.getTables(null, null, null,
          new String[]{"TABLE"})) {
        while (tables.next()) {
          var tableName = tables.getString("TABLE_NAME");
          var tableDescription = tables.getString("REMARKS");
          var tableCatalog = tables.getString("TABLE_CAT");
          var tableSchema = tables.getString("TABLE_SCHEM");
          var columnsInfo = new ArrayList<ColumnInfo>();
          try (var columns = metadata.getColumns(null, null, tableName, null)) {
            while (columns.next()) {
              var columnName = columns.getString("COLUMN_NAME");
              var datatype = columns.getString("TYPE_NAME");
              var columnDescription = columns.getString("REMARKS");
              columnsInfo.add(
                  new ColumnInfo(columnName, datatype, columnDescription)
              );
            }
          }
          tablesInfo.add(
              new TableInfo(
                  tableName,
                  tableDescription,
                  tableCatalog,
                  tableSchema,
                  columnsInfo
              )
          );
        }
      }
    }
    return new DatabaseMetadata(tablesInfo);
  }
}

Use Advisor

The extracted database metadata JSON string is set as the system text of the prompt sent to the LLM. A Spring AI advisor sets the system text.

DatabaseMetadataAdvisor is an implementation of CallAroundAdvisor. Before the request is sent to the LLM, the advisor updates it to include the system text prompt template and its variables. The prompt template has one variable, table_schemas. At runtime, this variable is replaced with the database metadata JSON string extracted using DatabaseMetadataHelper. The metadata is extracted once, when the advisor is constructed, so later schema changes are not picked up until the application restarts.
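
To make the placeholder substitution concrete, here is a plain Java sketch of what the rendered system text looks like. It only mimics the effect of Spring AI's prompt templating and is not the library's code; the class name SystemTextPreview, the render method, and the sample JSON are all hypothetical.

```java
import java.util.Map;

class SystemTextPreview {

  // Same shape as DEFAULT_SYSTEM_TEXT in DatabaseMetadataAdvisor, shortened.
  static final String TEMPLATE = """
      You are a Postgres expert.

      ===Tables
      {table_schemas}
      """;

  // Replaces each {name} placeholder with its value. This is a naive
  // stand-in for the real template engine, for illustration only.
  static String render(String template, Map<String, String> params) {
    var rendered = template;
    for (var entry : params.entrySet()) {
      rendered = rendered.replace("{" + entry.getKey() + "}", entry.getValue());
    }
    return rendered;
  }

  public static void main(String[] args) {
    // Hypothetical metadata JSON, shortened for readability.
    var metadataJson = "{\"tables\":[{\"name\":\"movies\"}]}";
    System.out.print(render(TEMPLATE, Map.of("table_schemas", metadataJson)));
  }
}
```

Printing the rendered text like this is a quick way to check how much of the prompt is taken up by table metadata before sending it to the model.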

package com.javaaidev.text2sql;

import com.javaaidev.text2sql.metadata.DatabaseMetadataHelper;
import java.sql.SQLException;
import java.util.HashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.advisor.api.AdvisedRequest;
import org.springframework.ai.chat.client.advisor.api.AdvisedResponse;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisor;
import org.springframework.ai.chat.client.advisor.api.CallAroundAdvisorChain;
import org.springframework.core.Ordered;

/**
* Advisor to update system text of the prompt
*/
public class DatabaseMetadataAdvisor implements CallAroundAdvisor {

  private static final String DEFAULT_SYSTEM_TEXT = """
      You are a Postgres expert. Please help to generate a Postgres query, then run the query to answer the question. The output should be in tabular format.

      ===Tables
      {table_schemas}
      """;

  private final DatabaseMetadataHelper databaseMetadataHelper;
  private final String tableSchemas;
  private static final Logger LOGGER = LoggerFactory.getLogger(
      DatabaseMetadataAdvisor.class);

  public DatabaseMetadataAdvisor(
      DatabaseMetadataHelper databaseMetadataHelper) {
    this.databaseMetadataHelper = databaseMetadataHelper;
    this.tableSchemas = getDatabaseMetadata();
    LOGGER.info("Loaded database metadata: {}", this.tableSchemas);
  }

  private String getDatabaseMetadata() {
    try {
      return databaseMetadataHelper.extractMetadataJson();
    } catch (SQLException e) {
      LOGGER.error("Failed to load database metadata", e);
      throw new RuntimeException(e);
    }
  }

  @Override
  public AdvisedResponse aroundCall(AdvisedRequest advisedRequest,
      CallAroundAdvisorChain chain) {
    var systemParams = new HashMap<>(advisedRequest.systemParams());
    systemParams.put("table_schemas", tableSchemas);
    var request = AdvisedRequest.from(advisedRequest)
        .systemText(DEFAULT_SYSTEM_TEXT)
        .systemParams(systemParams)
        .build();
    return chain.nextAroundCall(request);
  }

  @Override
  public String getName() {
    return getClass().getSimpleName();
  }

  @Override
  public int getOrder() {
    return Ordered.HIGHEST_PRECEDENCE;
  }
}

Execute SQL Statements

So far, the LLM only outputs SQL statements. However, end users may not care about the SQL itself; they want to see the actual query results. Otherwise, the generated statements still have to be executed manually in a database client. It’s better if the LLM outputs the query results instead of the SQL statements.

To get the query results, generated SQL statements need to be executed against the real database. LLM itself doesn’t have the capability to execute SQL statements. LLM requires an external tool to perform this action. This tool can execute SQL statements and return results.

Spring AI provides support for calling external tools. All we need to do is declare a Java Function.

RunSqlQueryTool implements the Function<RunSqlQueryRequest, RunSqlQueryResponse> interface. The tool’s input type is RunSqlQueryRequest, which defines a single field, query, for the SQL statement to execute. The tool’s output type is RunSqlQueryResponse, which defines two fields, success and error, for the successful result and the error message, respectively.
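
The request and response types are not listed in this article. The following sketch shows one plausible shape for them, together with a stand-in tool that returns canned data so the Function contract can be seen without a database. The component names follow the description above, and FakeRunSqlQueryTool is a hypothetical class that is not part of the project.

```java
import java.util.function.Function;

// Component names follow the description in the text; treat them as
// assumptions rather than the repository's actual definitions.
record RunSqlQueryRequest(String query) {
}

record RunSqlQueryResponse(String success, String error) {
}

// Stand-in for RunSqlQueryTool: returns canned CSV instead of querying
// a database, to illustrate the input and output contract only.
class FakeRunSqlQueryTool
    implements Function<RunSqlQueryRequest, RunSqlQueryResponse> {

  @Override
  public RunSqlQueryResponse apply(RunSqlQueryRequest request) {
    if (request.query() == null || request.query().isBlank()) {
      // Errors are reported in the response rather than thrown,
      // so the LLM can read the message and try again.
      return new RunSqlQueryResponse(null, "Query must not be empty");
    }
    return new RunSqlQueryResponse("count\n42\n", null);
  }
}
```

Exactly one of the two response fields is set on each call, which lets the model distinguish a result from a failure.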

In the implementation of RunSqlQueryTool, I use JdbcClient from Spring JDBC to execute SQL statements. The execution result type is List<Map<String, Object>>. Each Map in the List represents a row in the ResultSet, keyed by column name. The rows are then converted to CSV format, with the column names as the first row. If execution fails, the exception message is returned as the error result.

package com.javaaidev.text2sql.tool;

import java.io.IOException;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.util.CollectionUtils;

/**
* Use {@linkplain JdbcClient} to run SQL query and output result in CSV format
*/
public class RunSqlQueryTool implements
    Function<RunSqlQueryRequest, RunSqlQueryResponse> {

  private final JdbcClient jdbcClient;

  private static final Logger LOGGER = LoggerFactory.getLogger(
      RunSqlQueryTool.class);

  public RunSqlQueryTool(JdbcClient jdbcClient) {
    this.jdbcClient = jdbcClient;
  }

  @Override
  public RunSqlQueryResponse apply(RunSqlQueryRequest request) {
    try {
      LOGGER.info("SQL query to run [{}]", request.query());
      return new RunSqlQueryResponse(runQuery(request.query()), null);
    } catch (Exception e) {
      return new RunSqlQueryResponse(null, e.getMessage());
    }
  }

  private String runQuery(String query) {
    var rows = jdbcClient.sql(query)
        .query().listOfRows();
    if (CollectionUtils.isEmpty(rows)) {
      return "";
    }
    var fields = rows.getFirst().keySet().stream().sorted().toList();
    var format = CSVFormat.DEFAULT.builder()
        .setHeader(fields.toArray(new String[0]))
        .setSkipHeaderRecord(false)
        .setRecordSeparator('\n')
        .build();
    var builder = new StringBuilder();
    // CSVPrinter writes the header row when it is created;
    // CSVFormat.printRecord alone would emit only the data rows
    try (var printer = new CSVPrinter(builder, format)) {
      for (Map<String, Object> row : rows) {
        printer.printRecord(fields.stream().map(row::get).toArray());
      }
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
    return builder.toString();
  }
}
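
To see the shape of the tool output without a database or Commons CSV, here is a dependency-free sketch of the same row-to-CSV conversion. The class CsvSketch and its sample row are hypothetical, and the quoting rules are a simplification that only approximates what CSVFormat.DEFAULT does.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

class CsvSketch {

  // Quotes a value when it contains a comma, quote, or line break and
  // doubles embedded quotes. Nulls become empty fields.
  static String escape(Object value) {
    if (value == null) {
      return "";
    }
    var text = value.toString();
    if (text.contains(",") || text.contains("\"")
        || text.contains("\n") || text.contains("\r")) {
      return "\"" + text.replace("\"", "\"\"") + "\"";
    }
    return text;
  }

  // Header row first, then one line per row, columns sorted by name.
  static String toCsv(List<Map<String, Object>> rows) {
    if (rows.isEmpty()) {
      return "";
    }
    var fields = rows.get(0).keySet().stream().sorted().toList();
    var csv = new StringBuilder();
    csv.append(fields.stream()
        .map(CsvSketch::escape)
        .collect(Collectors.joining(","))).append('\n');
    for (var row : rows) {
      csv.append(fields.stream()
          .map(field -> escape(row.get(field)))
          .collect(Collectors.joining(","))).append('\n');
    }
    return csv.toString();
  }

  public static void main(String[] args) {
    // Hypothetical row; real rows come from JdbcClient.
    Map<String, Object> row = new LinkedHashMap<>();
    row.put("title", "Heat, Dust");
    row.put("id", 1);
    System.out.print(toCsv(List.of(row)));
  }
}
```

Sorting the column names gives a stable column order, at the cost of not following the order of the SELECT list.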

To register RunSqlQueryTool as a function, we define a bean of type RunSqlQueryTool. The bean name becomes the tool’s name. The @Description annotation adds the tool’s description. From the generic type definition of RunSqlQueryTool, Spring AI detects the input and output types, and builds a JSON schema from the Java type definitions. This information is passed to the LLM in the API request.

Register Spring AI function as bean
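import com.javaaidev.text2sql.tool.RunSqlQueryTool;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Description;
import org.springframework.jdbc.core.simple.JdbcClient;

@Configuration
public class AppConfiguration {

  // The bean name, runSqlQuery, is the tool name referenced in requests
  @Bean
  @Description("Query database using SQL")
  public RunSqlQueryTool runSqlQuery(JdbcClient jdbcClient) {
    return new RunSqlQueryTool(jdbcClient);
  }
}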
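
To give a feel for how a JSON schema can be derived from a Java type, the sketch below builds a very small schema from a record's components using reflection. It is an illustration only: Spring AI uses its own schema generator, whose output is more detailed and may differ, and the class SchemaSketch with its helper methods is hypothetical.

```java
import java.util.Arrays;
import java.util.stream.Collectors;

class SchemaSketch {

  // Same shape as the tool input described in the text.
  record RunSqlQueryRequest(String query) {
  }

  // Maps a few Java types to JSON Schema type names; anything else is
  // reported as "object". A real generator handles far more cases.
  static String jsonType(Class<?> type) {
    if (type == String.class) {
      return "string";
    }
    if (type == boolean.class || type == Boolean.class) {
      return "boolean";
    }
    if (Number.class.isAssignableFrom(type) || type == int.class
        || type == long.class || type == double.class) {
      return "number";
    }
    return "object";
  }

  // Lists each record component as a schema property, in declaration order.
  static String schemaOf(Class<? extends Record> recordType) {
    var properties = Arrays.stream(recordType.getRecordComponents())
        .map(c -> "\"" + c.getName() + "\":{\"type\":\""
            + jsonType(c.getType()) + "\"}")
        .collect(Collectors.joining(","));
    return "{\"type\":\"object\",\"properties\":{" + properties + "}}";
  }

  public static void main(String[] args) {
    System.out.println(schemaOf(RunSqlQueryRequest.class));
  }
}
```

For RunSqlQueryRequest this yields an object schema with a single string property named query, which is the kind of description the LLM needs in order to call the tool with a valid argument.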

REST Controller

Now we can build the REST API to implement Text-to-SQL. ChatClient.Builder is used to build the ChatClient. DatabaseMetadataAdvisor is added as a default advisor for the ChatClient.

Here I use OpenAI GPT-4o mini as the model. runSqlQuery is included as a function name in the request, so the LLM can use this tool to execute the generated SQL statements.

package com.javaaidev.text2sql.controller;

import com.javaaidev.text2sql.DatabaseMetadataAdvisor;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class ChatController {

  private final ChatClient chatClient;

  public ChatController(ChatClient.Builder builder,
      DatabaseMetadataAdvisor databaseMetadataAdvisor) {
    chatClient = builder.defaultAdvisors(databaseMetadataAdvisor).build();
  }

  @PostMapping("/chat")
  public ChatResponse chat(@RequestBody ChatRequest request) {
    return new ChatResponse(
        chatClient.prompt().user(request.input())
            .options(OpenAiChatOptions.builder()
                .model(ChatModel.GPT_4_O_MINI)
                .temperature(0.0)
                .function("runSqlQuery")
                .build())
            .call().content());
  }
}

Test

Now we can test the REST API. Start the server and use Swagger UI to run a query.
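
If you prefer to call the endpoint from code, the following sketch uses the JDK's built-in HttpClient. It assumes the request body is a JSON object with a single input field, which matches request.input() in ChatController. The class name ChatApiClient is hypothetical, and the base URL is passed as an argument because the server's address and port are not stated here.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

class ChatApiClient {

  // Builds a POST /chat request whose JSON body has a single "input" field.
  static HttpRequest buildRequest(String baseUrl, String question) {
    // Minimal JSON string escaping for backslashes and double quotes.
    var escaped = question.replace("\\", "\\\\").replace("\"", "\\\"");
    return HttpRequest.newBuilder(URI.create(baseUrl + "/chat"))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(
            "{\"input\":\"" + escaped + "\"}"))
        .build();
  }

  public static void main(String[] args) throws Exception {
    if (args.length == 0) {
      System.out.println("Usage: java ChatApiClient.java <base-url>");
      return;
    }
    var request = buildRequest(args[0],
        "how many movies are produced in the United States?");
    var response = HttpClient.newHttpClient()
        .send(request, HttpResponse.BodyHandlers.ofString());
    // The body is the JSON form of ChatResponse returned by the controller.
    System.out.println(response.body());
  }
}
```

With the server running locally, pass its base URL as the first argument, for example http://localhost:8080 if the application uses Spring Boot's default port.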

Sample query:

how many movies are produced in the United States?

Output:

There are 2,058 movies produced in the United States.

Complete Text-to-SQL Course

The Text-to-SQL implementation described here doesn't use RAG to find the table metadata relevant to the input query; it sends the metadata of every table with each request. For the complete version, check out my Text-to-SQL course.

Text-to-SQL Book

If you prefer a text version, check out my Text-to-SQL book.