ESP32S3的CHAT GPT搭建对接:模型3.5turbo(串口聊天)

效果图

alt text

开发板选用了ESP32S3 R16N8版本 可自行尝试其它开发板 编译器 Arduino IDE

1. 密钥以及WiFi配置

// WiFi网络名称和密码
const char* ssid = "你的WiFi名称";     // WiFi网络名称
const char* password = "你的WiFi密码"; // WiFi密码

const char* gptAPIEndpoint = "https://api.openai-hk.com/v1/chat/completions"; // GPT API端点
const char* apiKey = "gpt api密钥"; // 你的OpenAI API密钥
// GPT模型版本	gpt-4o gpt-4-turbo-2024-04-09 gpt-4-0125-preview gpt-4 gpt-4-1106-preview gpt-4-vision-preview gpt-3.5 gpt-3.5-turbo-1106 gpt-3.5-16k gpt-4-all
const char* gptModel = "gpt-3.5-turbo"; 

2. 上下文,最大token数量,以及最大结构体

const int MAX_CONTEXT_ENTRIES = 20;//上下文共20个
const int MAX_MESSAGE_LENGTH = 4096;//消息最大长度4096,是每一条消息的最大长度
const int MAX_JSON_SIZE = 8192;//json最大长度8192;也就是json结构体最大长度
String contextMessages[MAX_CONTEXT_ENTRIES * 2];//上下文消息
int contextIndex = 0;//用于记录上下文消息的索引

3.枚举

enum State {
  WAITING_FOR_INPUT,//等待输入
  SENDING_REQUEST,//发送请求
  WAITING_FOR_RESPONSE//等待响应
};
State currentState = WAITING_FOR_INPUT;//当前状态,用于记录是否发布消息和接收消息
String userMessage;//用户输入的消息

4.初始化部分

4.1 使用更大空间的存储,以保证能存储足够的消息,更改了Flash Size:16MB;Partition Scheme: Huge APP(3MB No OTA/1MB SPIFES);PSRAM:QSPI;

alt text 示例代码如下

// 检查PSRAM是否启用
  if (psramFound()) {
    Serial.println("PSRAM found and initialized.");
  } else {
    Serial.println("No PSRAM found.");
  }
    // 初始化SPIFFS
  if (!SPIFFS.begin(true)) {
    Serial.println("An Error has occurred while mounting SPIFFS");
    return;
  }

5.WiFi连接

// 连接到WiFi网络
  Serial.println();
  Serial.print("Connecting to ");
  Serial.println(ssid);
  WiFi.begin(ssid, password);
  while (WiFi.status() != WL_CONNECTED) {
    delay(500);
    Serial.print(".");
  }
  Serial.println("");
  Serial.println("WiFi connected");

  Serial.println("Enter your message:");

6.字符转义

String sanitizeMessage(const String& message) {//为了数据能正确的嵌入到json中,需要将特殊字符转义
  String sanitized = message;//这里将输入字符串message复制到变量sanitized,以便对其进行修改。
  sanitized.replace("\\", "\\\\");//在JSON字符串中,反斜杠是一个转义字符。如果字符串中包含反斜杠,它需要被转义为双反斜杠\\,以避免解析错误。
  sanitized.replace("\"", "\\\"");
  sanitized.replace("\n", "\\n");
  sanitized.replace("\r", "\\r");
  return sanitized;
}

7.管理上下文

void addMessageToContext(const String& role, String content) {
  if (content.length() > MAX_MESSAGE_LENGTH) {//如果消息长度大于4096,截取前4096个字符
    content = content.substring(0, MAX_MESSAGE_LENGTH);
  }
//把role和content拼接成json格式,会进行转义,结果示例:{"role": "user", "content": "你好"}
  String message = "{\"role\": \"" + role + "\", \"content\": \"" + sanitizeMessage(content) + "\"}";


  if (contextIndex >= MAX_CONTEXT_ENTRIES * 2) {//检测是否超过上下文数组的大小   乘以2是因为在和gpt交互时,用户输入的消息和gpt的回复都会被存储,所以是2条消息
    for (int i = 0; i < (MAX_CONTEXT_ENTRIES - 1) * 2; i++) {//如果超过就会删掉最顶上的两条旧信息。确保新消息能存储
      contextMessages[i] = contextMessages[i + 2];
    }
    contextIndex = (MAX_CONTEXT_ENTRIES - 1) * 2;
  }

  contextMessages[contextIndex++] = message;
}

8.内容存储

//主要是存储上下文信息,以便下次使用
void saveContextToSPIFFS() {//将上下文信息存储到SPIFFS中
  File file = SPIFFS.open("/context.txt", FILE_WRITE);//写入模式打开文件,如果没有文件则自动新建一个
  if (!file) {//检测是否打开成功
    Serial.println("Failed to open file for writing");
    return;
  }

  for (int i = 0; i < contextIndex; i++) {//辩论上下文消息,contextIndex是上下文消息的索引,所以是上下文消息的个数
    file.println(contextMessages[i]);//将每条消息写入文件。println方法会在每条消息后添加一个换行符,使每条消息占一行,便于以后读取和解析。
  }

  file.close();//关闭文件
}

9.内容读取

//加载对话上下文,并将其存储到上下文数组中
void loadContextFromSPIFFS() {
  File file = SPIFFS.open("/context.txt", FILE_READ);//读取模式打开
  if (!file) {//打不开就你懂的
    Serial.println("Failed to open file for reading");
    return;
  }

  contextIndex = 0;//从头开始读喽
  while (file.available() && contextIndex < MAX_CONTEXT_ENTRIES * 2) {//遍历文件的每一行,直到文件结束或上下文数组已满
    contextMessages[contextIndex++] = file.readStringUntil('\n');//读取文件的一行,读到换行符就算一行读取结束,并将其存储到上下文数组中,然后递增上下文数组的索引
  }

  file.close();//这里,你懂的
}

10.请求处理

//把json发送到api
bool sendRequestWithRetries(String jsonRequest, int retries = 3) {
  HTTPClient http;//创建http客户端对象
  http.begin(gptAPIEndpoint);//初始化http客户端 设置请求地址
  http.addHeader("Content-Type", "application/json");//请求头+知道内容类型为json
  http.addHeader("Authorization", "Bearer " + String(apiKey));//添加请求头,设置授权信息,使用Bearer Token进行身份验证。
  http.setTimeout(20000); // 20秒超时时间
  //重试机制和发送请求
  int attempt = 0;//重试次数
  int httpResponseCode = -1;//http响应码,-1表示没有响应

  while (attempt < retries && httpResponseCode <= 0) {//开始重试循环,只要重试次数未超过retries且响应码为负(表示请求失败)就继续重试。
    attempt++;//增加重试计数器。
    httpResponseCode = http.POST(jsonRequest);//发送HTTP POST请求并获取响应码

    if (httpResponseCode > 0) {//检测响应码是否为正数
      String response = http.getString();//获取响应内容
      Serial.println("Response:");
      Serial.println(response);//打印响应内容到串口。

      // 解析JSON响应
      DynamicJsonDocument jsonBuffer(32768); // 增大DynamicJsonDocument大小,防止内存溢出
      DeserializationError error = deserializeJson(jsonBuffer, response);//解析json结构体

      if (error) {//检测是否解析成功
        Serial.print("deserializeJson() failed: ");
        Serial.println(error.f_str());//打印错误信息
        return false;
      }

      // 获取生成的文本
      if (jsonBuffer.containsKey("choices")) {//检测是否包含choices字段
        String generatedText = jsonBuffer["choices"][0]["message"]["content"].as<String>();//获取choices字段中的message字段中的content字段

        Serial.println("Generated Text:");
        Serial.println(generatedText);//打印生成的文本

        // 添加生成的文本到上下文数组
        addMessageToContext("assistant", generatedText);
        saveContextToSPIFFS(); // 保存上下文到SPIFFS

        return true;//返回true
      } else {//没有choices字段
        Serial.println("Error: No 'choices' field in response");
        return false;
      }
    } else {//没有响应
      Serial.print("Error on HTTP request (attempt ");
      Serial.print(attempt);
      Serial.print("): ");
      Serial.println(httpResponseCode);
      delay(2000); // 等待2秒后重试
    }
  }

  http.end();//关闭http客户端,释放内存
  return false;//全失败则表示重试次数已用完,返回false
}

11.构建JSON请求

void sendRequest() {
  // 构建JSON请求
  //存储请求体。请求体包含模型的参数(如最大token数、模型版本、惊喜度、top_p和presence_penalty)以及消息数组的起始部分。
  String jsonRequest = "{\"max_tokens\": 4096, \"model\": \"" + String(gptModel) + "\", \"temperature\": 0.8, \"top_p\": 1, \"presence_penalty\": 1, \"messages\": [";

  for (int i = 0; contextIndex > 0 && i < contextIndex; i++) {//遍历上下文数组
    jsonRequest += contextMessages[i] + ",";//将上下文消息添加到请求体中
  }

  if (contextIndex > 0) {//如果上下文数组不为空
    jsonRequest.remove(jsonRequest.length() - 1);//删除最后一个逗号
  }
  jsonRequest += "]}";//添加结束符

  if (jsonRequest.length() > MAX_JSON_SIZE) {//检测请求体是否超过最大长度
    Serial.println("Error: JSON request size exceeds the limit");//打印错误信息
    currentState = WAITING_FOR_INPUT;//将状态设置为等待输入
    return;
  }

  Serial.println("JSON Request:");//打印请求体
  Serial.println(jsonRequest);

  if (!sendRequestWithRetries(jsonRequest)) {//如果请求发送失败并在重试多次后仍然失败,打印错误信息。如果请求成功,更新程序状态为 WAITING_FOR_INPUT,表示等待新的用户输入,并打印提示信息。
    Serial.println("Failed to send request after multiple attempts");
  } else {
    currentState = WAITING_FOR_INPUT;
    Serial.println("Enter your message:");
  }
}

12.主循环

void loop() {
  switch (currentState) {//根据状态执行相应的操作
    case WAITING_FOR_INPUT://状态 等待输入
      if (WiFi.status() == WL_CONNECTED) {//判断wifi是否连接
        if (Serial.available()) {//检测串口是否有数据
          userMessage = Serial.readStringUntil('\n');//读取串口数据,存储到userMessage变量中
          addMessageToContext("user", userMessage);//将用户输入的消息添加到上下文数组中
          saveContextToSPIFFS();//保存上下文到SPIFFS
          currentState = SENDING_REQUEST;//将状态设置为发送请求
        }
      } else {//wifi未连接
        Serial.println("WiFi Disconnected");
      }
      break;

    case SENDING_REQUEST://状态 发送请求
      sendRequest();//发送请求
      break;

    case WAITING_FOR_RESPONSE:
      // 等待响应处理完成
      break;
  }
}

源码

#include <WiFi.h>
#include <HTTPClient.h>
#include <ArduinoJson.h>
#include <SPIFFS.h>

// WiFi网络名称和密码
const char* ssid = "WiFi网络名称";     // WiFi网络名称
const char* password = "WiFi网络密码"; // WiFi密码

const char* gptAPIEndpoint = "https://api.openai-hk.com/v1/chat/completions"; // GPT API端点
const char* apiKey = "密钥"; // 你的OpenAI API密钥
const char* gptModel = "gpt-3.5-turbo"; // GPT模型版本	

const int MAX_CONTEXT_ENTRIES = 20;
const int MAX_MESSAGE_LENGTH = 4096;
const int MAX_JSON_SIZE = 8192;
String contextMessages[MAX_CONTEXT_ENTRIES * 2];
int contextIndex = 0;

enum State {
  WAITING_FOR_INPUT,
  SENDING_REQUEST,
  WAITING_FOR_RESPONSE
};

State currentState = WAITING_FOR_INPUT;
String userMessage;

void setup() {
  Serial.begin(115200);
  delay(100);

  // 检查PSRAM是否启用
  if (psramFound()) {
    Serial.println("PSRAM found and initialized.");
  } else {
    Serial.println("No PSRAM found.");
  }

  // 初始化SPIFFS
  if (!SPIFFS.begin(true)) {
    Serial.println("An Error has occurred while mounting SPIFFS");
    return;
  }

  // 连接到WiFi网络
  Serial.println();
  Serial.print("Connecting to ");
  Serial.println(ssid);
  WiFi.begin(ssid, password);
  while (WiFi.status() != WL_CONNECTED) {
    delay(500);
    Serial.print(".");
  }
  Serial.println("");
  Serial.println("WiFi connected");

  Serial.println("Enter your message:");
}

String sanitizeMessage(const String& message) {
  String sanitized = message;
  sanitized.replace("\\", "\\\\");
  sanitized.replace("\"", "\\\"");
  sanitized.replace("\n", "\\n");
  sanitized.replace("\r", "\\r");
  return sanitized;
}

void addMessageToContext(const String& role, String content) {
  if (content.length() > MAX_MESSAGE_LENGTH) {
    content = content.substring(0, MAX_MESSAGE_LENGTH);
  }

  String message = "{\"role\": \"" + role + "\", \"content\": \"" + sanitizeMessage(content) + "\"}";

  if (contextIndex >= MAX_CONTEXT_ENTRIES * 2) {
    for (int i = 0; i < (MAX_CONTEXT_ENTRIES - 1) * 2; i++) {
      contextMessages[i] = contextMessages[i + 2];
    }
    contextIndex = (MAX_CONTEXT_ENTRIES - 1) * 2;
  }

  contextMessages[contextIndex++] = message;
}

void saveContextToSPIFFS() {
  File file = SPIFFS.open("/context.txt", FILE_WRITE);
  if (!file) {
    Serial.println("Failed to open file for writing");
    return;
  }

  for (int i = 0; i < contextIndex; i++) {
    file.println(contextMessages[i]);
  }

  file.close();
}

void loadContextFromSPIFFS() {
  File file = SPIFFS.open("/context.txt", FILE_READ);
  if (!file) {
    Serial.println("Failed to open file for reading");
    return;
  }

  contextIndex = 0;
  while (file.available() && contextIndex < MAX_CONTEXT_ENTRIES * 2) {
    contextMessages[contextIndex++] = file.readStringUntil('\n');
  }

  file.close();
}

bool sendRequestWithRetries(String jsonRequest, int retries = 3) {
  HTTPClient http;
  http.begin(gptAPIEndpoint);
  http.addHeader("Content-Type", "application/json");
  http.addHeader("Authorization", "Bearer " + String(apiKey));
  http.setTimeout(20000); // 20秒超时时间

  int attempt = 0;
  int httpResponseCode = -1;

  while (attempt < retries && httpResponseCode <= 0) {
    attempt++;
    httpResponseCode = http.POST(jsonRequest);

    if (httpResponseCode > 0) {
      String response = http.getString();
      Serial.println("Response:");
      Serial.println(response);

      // 解析JSON响应
      DynamicJsonDocument jsonBuffer(32768); // 增大DynamicJsonDocument大小
      DeserializationError error = deserializeJson(jsonBuffer, response);

      if (error) {
        Serial.print("deserializeJson() failed: ");
        Serial.println(error.f_str());
        return false;
      }

      // 获取生成的文本
      if (jsonBuffer.containsKey("choices")) {
        String generatedText = jsonBuffer["choices"][0]["message"]["content"].as<String>();

        Serial.println("Generated Text:");
        Serial.println(generatedText);

        // 添加生成的文本到上下文数组
        addMessageToContext("assistant", generatedText);
        saveContextToSPIFFS(); // 保存上下文到SPIFFS

        return true;
      } else {
        Serial.println("Error: No 'choices' field in response");
        return false;
      }
    } else {
      Serial.print("Error on HTTP request (attempt ");
      Serial.print(attempt);
      Serial.print("): ");
      Serial.println(httpResponseCode);
      delay(2000); // 等待2秒后重试
    }
  }

  http.end();
  return false;
}

void sendRequest() {
  // 构建JSON请求
  String jsonRequest = "{\"max_tokens\": 4096, \"model\": \"" + String(gptModel) + "\", \"temperature\": 0.8, \"top_p\": 1, \"presence_penalty\": 1, \"messages\": [";

  for (int i = 0; contextIndex > 0 && i < contextIndex; i++) {
    jsonRequest += contextMessages[i] + ",";
  }

  if (contextIndex > 0) {
    jsonRequest.remove(jsonRequest.length() - 1);
  }
  jsonRequest += "]}";

  if (jsonRequest.length() > MAX_JSON_SIZE) {
    Serial.println("Error: JSON request size exceeds the limit");
    currentState = WAITING_FOR_INPUT;
    return;
  }

  Serial.println("JSON Request:");
  Serial.println(jsonRequest);

  if (!sendRequestWithRetries(jsonRequest)) {
    Serial.println("Failed to send request after multiple attempts");
  } else {
    currentState = WAITING_FOR_INPUT;
    Serial.println("Enter your message:");
  }
}

void loop() {
  switch (currentState) {
    case WAITING_FOR_INPUT:
      if (WiFi.status() == WL_CONNECTED) {
        if (Serial.available()) {
          userMessage = Serial.readStringUntil('\n');
          addMessageToContext("user", userMessage);
          saveContextToSPIFFS();
          currentState = SENDING_REQUEST;
        }
      } else {
        Serial.println("WiFi Disconnected");
      }
      break;

    case SENDING_REQUEST:
      sendRequest();
      break;

    case WAITING_FOR_RESPONSE:
      // 等待响应处理完成
      break;
  }
}