> generateWriteContent(@RequestBody @Valid AiWriteGenerateReqVO generateReqVO) {
return writeService.generateWriteContent(generateReqVO, getLoginUserId());
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWritePageReqVO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWritePageReqVO.java
index 047380e422..04f99ae13c 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWritePageReqVO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/write/vo/AiWritePageReqVO.java
@@ -13,8 +13,6 @@ import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_
@Schema(description = "管理后台 - AI 写作分页 Request VO")
@Data
-@EqualsAndHashCode(callSuper = true)
-@ToString(callSuper = true)
public class AiWritePageReqVO extends PageParam {
@Schema(description = "用户编号", example = "28404")
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
index 7d9625f58f..23aec276db 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatConversationDO.java
@@ -1,9 +1,8 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@@ -22,7 +21,6 @@ import java.time.LocalDateTime;
@TableName("ai_chat_conversation")
@KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data
-@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
@@ -65,21 +63,16 @@ public class AiChatConversationDO extends BaseDO {
*/
private Long roleId;
- /**
- * 知识库编号
- *
- * 关联 {@link AiKnowledgeDO#getId()}
- */
- private Long knowledgeId;
-
/**
* 模型编号
*
- * 关联 {@link AiChatModelDO#getId()} 字段
+ * 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;
/**
* 模型标志
+ *
+ * 冗余 {@link AiModelDO#getModel()} 字段
*/
private String model;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
index ecd10609f5..2364d750cb 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/chat/AiChatMessageDO.java
@@ -1,14 +1,14 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.chat;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
-import com.baomidou.mybatisplus.extension.handlers.JacksonTypeHandler;
import lombok.*;
import org.springframework.ai.chat.messages.MessageType;
@@ -20,10 +20,9 @@ import java.util.List;
* @since 2024/4/14 17:35
* @since 2024/4/14 17:35
*/
-@TableName("ai_chat_message")
+@TableName(value = "ai_chat_message", autoResultMap = true)
@KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data
-@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
@@ -71,23 +70,16 @@ public class AiChatMessageDO extends BaseDO {
*/
private Long roleId;
-
- /**
- * 段落编号数组
- *
- * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
- */
- @TableField(typeHandler = JacksonTypeHandler.class)
- private List segmentIds;
-
/**
* 模型标志
+ *
+ * 冗余 {@link AiModelDO#getModel()}
*/
private String model;
/**
* 模型编号
*
- * 关联 {@link AiChatModelDO#getId()} 字段
+ * 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;
@@ -101,4 +93,12 @@ public class AiChatMessageDO extends BaseDO {
*/
private Boolean useContext;
+ /**
+ * 知识库段落编号数组
+ *
+ * 关联 {@link AiKnowledgeSegmentDO#getId()} 字段
+ */
+ @TableField(typeHandler = LongListTypeHandler.class)
+ private List segmentIds;
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java
index 56749a1d00..a18904c022 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/image/AiImageDO.java
@@ -2,7 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.image;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import com.baomidou.mybatisplus.annotation.KeySequence;
@@ -53,9 +53,15 @@ public class AiImageDO extends BaseDO {
*/
private String platform;
/**
- * 模型
+ * 模型编号
*
- * 冗余 {@link AiChatModelDO#getModel()}
+ * 关联 {@link AiModelDO#getId()}
+ */
+ private Long modelId;
+ /**
+ * 模型标识
+ *
+ * 冗余 {@link AiModelDO#getModel()}
*/
private String model;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
index 638a8ba50b..e1327a50ef 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDO.java
@@ -2,15 +2,12 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
-import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
-import java.util.List;
-
/**
* AI 知识库 DO
*
@@ -26,12 +23,6 @@ public class AiKnowledgeDO extends BaseDO {
*/
@TableId
private Long id;
- /**
- * 用户编号
- *
- * 关联 AdminUserDO 的 userId 字段
- */
- private Long userId;
/**
* 知识库名称
*/
@@ -42,20 +33,17 @@ public class AiKnowledgeDO extends BaseDO {
private String description;
/**
- * 可见权限,选择哪些人可见
- *
- * -1 所有人可见,其他为各自用户编号
+ * 向量模型编号
+ *
+ * 关联 {@link AiModelDO#getId()}
*/
- @TableField(typeHandler = LongListTypeHandler.class)
- private List visibilityPermissions;
- /**
- * 嵌入模型编号
- */
- private Long modelId;
+ private Long embeddingModelId;
/**
* 模型标识
+ *
+ * 冗余 {@link AiModelDO#getModel()}
*/
- private String model;
+ private String embeddingModel;
/**
* topK
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
index ee8bfd5aab..ac014e926b 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeDocumentDO.java
@@ -2,7 +2,6 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.knowledge;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
-import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@@ -30,57 +29,35 @@ public class AiKnowledgeDocumentDO extends BaseDO {
*/
private Long knowledgeId;
/**
- * 文件名称
+ * 文档名称
*/
private String name;
- /**
- * 内容
- */
- private String content;
/**
* 文件 URL
*/
private String url;
+ /**
+ * 内容
+ */
+ private String content;
+ /**
+ * 文档长度
+ */
+ private Integer contentLength;
+
/**
* 文档 token 数量
*/
private Integer tokens;
/**
- * 文档字符数
+ * 分片最大 Token 数
*/
- private Integer wordCount;
-
-
- // ========== 自定义分段所用参数 ==========
- // TODO @新:3)defaultChunkSize、defaultChunkSize、minChunkSizeChars、maxNumChunks 这几个字段的命名,可能要微信一起讨论下。尽量命名保持风格统一哈。
- /**
- * 每个文本块的目标 token 数
- */
- private Integer defaultSegmentTokens;
- /**
- * 每个文本块的最小字符数
- */
- private Integer minSegmentWordCount;
- /**
- * 低于此值的块会被丢弃
- */
- private Integer minChunkLengthToEmbed;
- /**
- * 最大块数
- */
- private Integer maxNumSegments;
- /**
- * 分块是否保留分隔符
- */
- private Boolean keepSeparator;
- // ===================================
+ private Integer segmentMaxTokens;
/**
- * 切片状态
- *
- * 枚举 {@link AiKnowledgeDocumentStatusEnum}
+ * 召回次数
*/
- private Integer sliceStatus;
+ private Integer retrievalCount;
/**
* 状态
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
index b08e960d14..cccbd6846b 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/knowledge/AiKnowledgeSegmentDO.java
@@ -17,17 +17,16 @@ import lombok.Data;
@Data
public class AiKnowledgeSegmentDO extends BaseDO {
- public static final String FIELD_KNOWLEDGE_ID = "knowledgeId";
+ /**
+ * 向量库的编号 - 空值
+ */
+ public static final String VECTOR_ID_EMPTY = "";
/**
* 编号
*/
@TableId
private Long id;
- /**
- * 向量库的编号
- */
- private String vectorId;
/**
* 知识库编号
*
@@ -45,13 +44,24 @@ public class AiKnowledgeSegmentDO extends BaseDO {
*/
private String content;
/**
- * 字符数
+ * 切片内容长度
*/
- private Integer wordCount;
+ private Integer contentLength;
+
+ /**
+ * 向量库的编号
+ */
+ private String vectorId;
/**
* token 数量
*/
private Integer tokens;
+
+ /**
+ * 召回次数
+ */
+ private Integer retrievalCount;
+
/**
* 状态
*
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
index b9768529f1..6dd5d44302 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/mindmap/AiMindMapDO.java
@@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.mindmap;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
@@ -36,6 +37,12 @@ public class AiMindMapDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
+ /**
+ * 模型编号
+ *
+ * 关联 {@link AiModelDO#getId()}
+ */
+ private Long modelId;
/**
* 模型
*/
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiApiKeyDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiApiKeyDO.java
index e251d55c85..346811f0d5 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiApiKeyDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiApiKeyDO.java
@@ -16,7 +16,6 @@ import lombok.*;
@TableName("ai_api_key")
@KeySequence("ai_chat_conversation_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data
-@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java
index f5ed533a92..bb6a3ca48d 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatRoleDO.java
@@ -2,11 +2,16 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.model;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.framework.mybatis.core.type.LongListTypeHandler;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import com.baomidou.mybatisplus.annotation.KeySequence;
+import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
+import java.util.List;
+
/**
* AI 聊天角色 DO
*
@@ -16,7 +21,6 @@ import lombok.*;
@TableName(value = "ai_chat_role", autoResultMap = true)
@KeySequence("ai_chat_role_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data
-@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
@@ -58,10 +62,25 @@ public class AiChatRoleDO extends BaseDO {
/**
* 模型编号
*
- * 关联 {@link AiChatModelDO#getId()} 字段
+ * 关联 {@link AiModelDO#getId()} 字段
*/
private Long modelId;
+ /**
+ * 引用的知识库编号列表
+ *
+ * 关联 {@link AiKnowledgeDO#getId()} 字段
+ */
+ @TableField(typeHandler = LongListTypeHandler.class)
+ private List knowledgeIds;
+ /**
+ * 引用的工具编号列表
+ *
+ * 关联 {@link AiToolDO#getId()} 字段
+ */
+ @TableField(typeHandler = LongListTypeHandler.class)
+ private List toolIds;
+
/**
* 是否公开
*
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java
similarity index 76%
rename from yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java
rename to yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java
index 7197f8b58f..b39320291b 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiChatModelDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiModelDO.java
@@ -1,5 +1,6 @@
package cn.iocoder.yudao.module.ai.dal.dataobject.model;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
@@ -9,21 +10,20 @@ import com.baomidou.mybatisplus.annotation.TableName;
import lombok.*;
/**
- * AI 聊天模型 DO
+ * AI 模型 DO
*
- * 默认聊天模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
+ * 默认模型:{@link #status} 为开启,并且 {@link #sort} 排序第一
*
* @author fansili
* @since 2024/4/24 19:39
*/
-@TableName("ai_chat_model")
-@KeySequence("ai_chat_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
+@TableName("ai_model")
+@KeySequence("ai_model_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
@Data
-@EqualsAndHashCode(callSuper = true)
@Builder
@NoArgsConstructor
@AllArgsConstructor
-public class AiChatModelDO extends BaseDO {
+public class AiModelDO extends BaseDO {
/**
* 编号
@@ -50,6 +50,12 @@ public class AiChatModelDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
+ /**
+ * 类型
+ *
+ * 枚举 {@link AiModelTypeEnum}
+ */
+ private Integer type;
/**
* 排序值
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiToolDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiToolDO.java
new file mode 100644
index 0000000000..7773e978cc
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/model/AiToolDO.java
@@ -0,0 +1,48 @@
+package cn.iocoder.yudao.module.ai.dal.dataobject.model;
+
+import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.service.model.tool.DirectoryListToolFunction;
+import cn.iocoder.yudao.module.ai.service.model.tool.WeatherQueryToolFunction;
+import com.baomidou.mybatisplus.annotation.KeySequence;
+import com.baomidou.mybatisplus.annotation.TableId;
+import com.baomidou.mybatisplus.annotation.TableName;
+import lombok.*;
+
+/**
+ * AI 工具 DO
+ *
+ * @author 芋道源码
+ */
+@TableName("ai_tool")
+@KeySequence("ai_tool_seq") // 用于 Oracle、PostgreSQL、Kingbase、DB2、H2 数据库的主键自增。如果是 MySQL 等数据库,可不写。
+@Data
+@Builder
+@NoArgsConstructor
+@AllArgsConstructor
+public class AiToolDO extends BaseDO {
+
+ /**
+ * 工具编号
+ */
+ @TableId
+ private Long id;
+ /**
+ * 工具名称
+ *
+ * 对应 Bean 的名字,例如说:
+ * 1. {@link DirectoryListToolFunction} 的 Bean 名字是 directory_list
+ * 2. {@link WeatherQueryToolFunction} 的 Bean 名字是 weather_query
+ */
+ private String name;
+ /**
+ * 工具描述
+ */
+ private String description;
+ /**
+ * 状态
+ *
+ * 枚举 {@link cn.iocoder.yudao.framework.common.enums.CommonStatusEnum}
+ */
+ private Integer status;
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/music/AiMusicDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/music/AiMusicDO.java
index e03d62c162..bfa7394ddd 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/music/AiMusicDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/music/AiMusicDO.java
@@ -84,6 +84,7 @@ public class AiMusicDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
+ // TODO @芋艿:modelId?
/**
* 模型
*/
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java
index 0d6f9c5e64..e07f994aad 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/dataobject/write/AiWriteDO.java
@@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.dal.dataobject.write;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.mybatis.core.dataobject.BaseDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.enums.DictTypeConstants;
import cn.iocoder.yudao.module.ai.enums.write.AiWriteTypeEnum;
import com.baomidou.mybatisplus.annotation.KeySequence;
import com.baomidou.mybatisplus.annotation.TableId;
@@ -44,6 +46,12 @@ public class AiWriteDO extends BaseDO {
* 枚举 {@link AiPlatformEnum}
*/
private String platform;
+ /**
+ * 模型编号
+ *
+ * 关联 {@link AiModelDO#getId()}
+ */
+ private Long modelId;
/**
* 模型
*/
@@ -66,25 +74,25 @@ public class AiWriteDO extends BaseDO {
/**
* 长度提示词
*
- * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LENGTH}
+ * 字典:{@link DictTypeConstants#AI_WRITE_LENGTH}
*/
private Integer length;
/**
* 格式提示词
*
- * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_FORMAT}
+ * 字典:{@link DictTypeConstants#AI_WRITE_FORMAT}
*/
private Integer format;
/**
* 语气提示词
*
- * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_TONE}
+ * 字典:{@link DictTypeConstants#AI_WRITE_TONE}
*/
private Integer tone;
/**
* 语言提示词
*
- * 字典:{@link cn.iocoder.yudao.module.ai.enums.DictTypeConstants#AI_WRITE_LANGUAGE}
+ * 字典:{@link DictTypeConstants#AI_WRITE_LANGUAGE}
*/
private Integer language;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeDocumentMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeDocumentMapper.java
index 7692d1cede..11a76cc57b 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeDocumentMapper.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeDocumentMapper.java
@@ -5,10 +5,14 @@ import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
+import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import org.apache.ibatis.annotations.Mapper;
+import java.util.Collection;
+import java.util.List;
+
/**
- * AI 知识库-文档 Mapper
+ * AI 知识库文档 Mapper
*
* @author xiaoxin
*/
@@ -17,8 +21,19 @@ public interface AiKnowledgeDocumentMapper extends BaseMapperX selectPage(AiKnowledgeDocumentPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX()
+ .eqIfPresent(AiKnowledgeDocumentDO::getKnowledgeId, reqVO.getKnowledgeId())
.likeIfPresent(AiKnowledgeDocumentDO::getName, reqVO.getName())
.orderByDesc(AiKnowledgeDocumentDO::getId));
}
+ default void updateRetrievalCountIncr(Collection ids) {
+ update(new LambdaUpdateWrapper()
+ .setSql(" retrieval_count = retrieval_count + 1")
+ .in(AiKnowledgeDocumentDO::getId, ids));
+ }
+
+ default List selectListByStatus(Integer status) {
+ return selectList(AiKnowledgeDocumentDO::getStatus, status);
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeMapper.java
index f07a9a2afa..3433c0b973 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeMapper.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeMapper.java
@@ -1,6 +1,5 @@
package cn.iocoder.yudao.module.ai.dal.mysql.knowledge;
-import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
@@ -8,19 +7,26 @@ import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnow
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import org.apache.ibatis.annotations.Mapper;
+import java.util.List;
+
/**
- * AI 知识库基础信息 Mapper
+ * AI 知识库 Mapper
*
* @author xiaoxin
*/
@Mapper
public interface AiKnowledgeMapper extends BaseMapperX {
- default PageResult selectPage(Long userId, AiKnowledgePageReqVO pageReqVO) {
+ default PageResult selectPage(AiKnowledgePageReqVO pageReqVO) {
return selectPage(pageReqVO, new LambdaQueryWrapperX()
- .eq(AiKnowledgeDO::getStatus, CommonStatusEnum.ENABLE.getStatus())
.likeIfPresent(AiKnowledgeDO::getName, pageReqVO.getName())
- .and(e -> e.apply("FIND_IN_SET(" + userId + ",visibility_permissions)").or(m -> m.apply("FIND_IN_SET(-1,visibility_permissions)")))
+ .eqIfPresent(AiKnowledgeDO::getStatus, pageReqVO.getStatus())
+ .betweenIfPresent(AiKnowledgeDO::getCreateTime, pageReqVO.getCreateTime())
.orderByDesc(AiKnowledgeDO::getId));
}
+
+ default List selectListByStatus(Integer status) {
+ return selectList(AiKnowledgeDO::getStatus, status);
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
index 094f19b52e..00bacd9665 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/knowledge/AiKnowledgeSegmentMapper.java
@@ -3,14 +3,19 @@ package cn.iocoder.yudao.module.ai.dal.mysql.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.MPJLambdaWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
+import com.github.yulichang.wrapper.MPJLambdaWrapper;
import org.apache.ibatis.annotations.Mapper;
+import java.util.Collection;
import java.util.List;
/**
- * AI 知识库-分片 Mapper
+ * AI 知识库分片 Mapper
*
* @author xiaoxin
*/
@@ -20,8 +25,8 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX selectPage(AiKnowledgeSegmentPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX()
.eq(AiKnowledgeSegmentDO::getDocumentId, reqVO.getDocumentId())
+ .likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getContent())
.eqIfPresent(AiKnowledgeSegmentDO::getStatus, reqVO.getStatus())
- .likeIfPresent(AiKnowledgeSegmentDO::getContent, reqVO.getKeyword())
.orderByDesc(AiKnowledgeSegmentDO::getId));
}
@@ -31,4 +36,32 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX selectListByDocumentId(Long documentId) {
+ return selectList(new LambdaQueryWrapperX()
+ .eq(AiKnowledgeSegmentDO::getDocumentId, documentId)
+ .orderByDesc(AiKnowledgeSegmentDO::getId));
+ }
+
+ default List selectListByKnowledgeIdAndStatus(Long knowledgeId, Integer status) {
+ return selectList(AiKnowledgeSegmentDO::getKnowledgeId, knowledgeId,
+ AiKnowledgeSegmentDO::getStatus, status);
+ }
+
+ default List selectProcessList(Collection documentIds) {
+ MPJLambdaWrapper wrapper = new MPJLambdaWrapperX()
+ .selectAs(AiKnowledgeSegmentDO::getDocumentId, AiKnowledgeSegmentProcessRespVO::getDocumentId)
+ .selectCount(AiKnowledgeSegmentDO::getId, "count")
+ .select("COUNT(CASE WHEN vector_id > '" + AiKnowledgeSegmentDO.VECTOR_ID_EMPTY
+ + "' THEN 1 ELSE NULL END) AS embeddingCount")
+ .in(AiKnowledgeSegmentDO::getDocumentId, documentIds)
+ .groupBy(AiKnowledgeSegmentDO::getDocumentId);
+ return selectJoinList(AiKnowledgeSegmentProcessRespVO.class, wrapper);
+ }
+
+ default void updateRetrievalCountIncrByIds(List ids) {
+ update(new LambdaUpdateWrapper()
+ .setSql(" retrieval_count = retrieval_count + 1")
+ .in(AiKnowledgeSegmentDO::getId, ids));
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java
new file mode 100644
index 0000000000..bfe2caf52a
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatMapper.java
@@ -0,0 +1,47 @@
+package cn.iocoder.yudao.module.ai.dal.mysql.model;
+
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import org.apache.ibatis.annotations.Mapper;
+
+import javax.annotation.Nullable;
+import java.util.List;
+
+/**
+ * API 模型 Mapper
+ *
+ * @author fansili
+ */
+@Mapper
+public interface AiChatMapper extends BaseMapperX {
+
+ default AiModelDO selectFirstByStatus(Integer type, Integer status) {
+ return selectOne(new QueryWrapperX()
+ .eq("type", type)
+ .eq("status", status)
+ .limitN(1)
+ .orderByAsc("sort"));
+ }
+
+ default PageResult selectPage(AiModelPageReqVO reqVO) {
+ return selectPage(reqVO, new LambdaQueryWrapperX()
+ .likeIfPresent(AiModelDO::getName, reqVO.getName())
+ .eqIfPresent(AiModelDO::getModel, reqVO.getModel())
+ .eqIfPresent(AiModelDO::getPlatform, reqVO.getPlatform())
+ .orderByAsc(AiModelDO::getSort));
+ }
+
+ default List selectListByStatusAndType(Integer status, Integer type,
+ @Nullable String platform) {
+ return selectList(new LambdaQueryWrapperX()
+ .eq(AiModelDO::getStatus, status)
+ .eq(AiModelDO::getType, type)
+ .eqIfPresent(AiModelDO::getPlatform, platform)
+ .orderByAsc(AiModelDO::getSort));
+ }
+
+}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java
deleted file mode 100644
index a3136fa9f6..0000000000
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiChatModelMapper.java
+++ /dev/null
@@ -1,43 +0,0 @@
-package cn.iocoder.yudao.module.ai.dal.mysql.model;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
-import cn.iocoder.yudao.framework.mybatis.core.query.QueryWrapperX;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import org.apache.ibatis.annotations.Mapper;
-
-import java.util.Collection;
-import java.util.List;
-
-/**
- * API 聊天模型 Mapper
- *
- * @author fansili
- */
-@Mapper
-public interface AiChatModelMapper extends BaseMapperX {
-
- default AiChatModelDO selectFirstByStatus(Integer status) {
- return selectOne(new QueryWrapperX()
- .eq("status", status)
- .limitN(1)
- .orderByAsc("sort"));
- }
-
- default PageResult selectPage(AiChatModelPageReqVO reqVO) {
- return selectPage(reqVO, new LambdaQueryWrapperX()
- .likeIfPresent(AiChatModelDO::getName, reqVO.getName())
- .eqIfPresent(AiChatModelDO::getModel, reqVO.getModel())
- .eqIfPresent(AiChatModelDO::getPlatform, reqVO.getPlatform())
- .orderByAsc(AiChatModelDO::getSort));
- }
-
- default List selectList(Integer status) {
- return selectList(new LambdaQueryWrapperX()
- .eq(AiChatModelDO::getStatus, status)
- .orderByAsc(AiChatModelDO::getSort));
- }
-
-}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiToolMapper.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiToolMapper.java
new file mode 100644
index 0000000000..d5d296692a
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/dal/mysql/model/AiToolMapper.java
@@ -0,0 +1,35 @@
+package cn.iocoder.yudao.module.ai.dal.mysql.model;
+
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
+import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolPageReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
+import org.apache.ibatis.annotations.Mapper;
+
+import java.util.List;
+
+/**
+ * AI 工具 Mapper
+ *
+ * @author 芋道源码
+ */
+@Mapper
+public interface AiToolMapper extends BaseMapperX {
+
+ default PageResult selectPage(AiToolPageReqVO reqVO) {
+ return selectPage(reqVO, new LambdaQueryWrapperX()
+ .likeIfPresent(AiToolDO::getName, reqVO.getName())
+ .eqIfPresent(AiToolDO::getDescription, reqVO.getDescription())
+ .eqIfPresent(AiToolDO::getStatus, reqVO.getStatus())
+ .betweenIfPresent(AiToolDO::getCreateTime, reqVO.getCreateTime())
+ .orderByDesc(AiToolDO::getId));
+ }
+
+ default List selectListByStatus(Integer status) {
+ return selectList(new LambdaQueryWrapperX()
+ .eq(AiToolDO::getStatus, status)
+ .orderByDesc(AiToolDO::getId));
+ }
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/framework/web/package-info.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/framework/web/package-info.java
index 09de7263c5..e979056d4e 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/framework/web/package-info.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/framework/web/package-info.java
@@ -1,4 +1,4 @@
/**
- * crm 模块的 web 拓展封装
+ * ai 模块的 web 拓展封装
*/
-package cn.iocoder.yudao.module.crm.framework.web;
+package cn.iocoder.yudao.module.ai.framework.web;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
index 8f094087f1..6c35571c8f 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatConversationServiceImpl.java
@@ -4,17 +4,18 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ObjectUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationCreateMyReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.conversation.AiChatConversationUpdateMyReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatConversationMapper;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@@ -44,7 +45,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
private AiChatConversationMapper chatConversationMapper;
@Resource
- private AiChatModelService chatModalService;
+ private AiModelService modalService;
@Resource
private AiChatRoleService chatRoleService;
@Resource
@@ -54,9 +55,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
public Long createChatConversationMy(AiChatConversationCreateMyReqVO createReqVO, Long userId) {
// 1.1 获得 AiChatRoleDO 聊天角色
AiChatRoleDO role = createReqVO.getRoleId() != null ? chatRoleService.validateChatRole(createReqVO.getRoleId()) : null;
- // 1.2 获得 AiChatModelDO 聊天模型
- AiChatModelDO model = role != null && role.getModelId() != null ? chatModalService.validateChatModel(role.getModelId())
- : chatModalService.getRequiredDefaultChatModel();
+ // 1.2 获得 AiModelDO 聊天模型
+ AiModelDO model = role != null && role.getModelId() != null ? modalService.validateModel(role.getModelId())
+ : modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
Assert.notNull(model, "必须找到默认模型");
validateChatModel(model);
@@ -67,7 +68,7 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
// 2. 创建 AiChatConversationDO 聊天对话
AiChatConversationDO conversation = new AiChatConversationDO().setUserId(userId).setPinned(false)
- .setModelId(model.getId()).setModel(model.getModel()).setKnowledgeId(createReqVO.getKnowledgeId())
+ .setModelId(model.getId()).setModel(model.getModel())
.setTemperature(model.getTemperature()).setMaxTokens(model.getMaxTokens()).setMaxContexts(model.getMaxContexts());
if (role != null) {
conversation.setTitle(role.getName()).setRoleId(role.getId()).setSystemMessage(role.getSystemMessage());
@@ -86,9 +87,9 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
// 1.2 校验模型是否存在(修改模型的情况)
- AiChatModelDO model = null;
+ AiModelDO model = null;
if (updateReqVO.getModelId() != null) {
- model = chatModalService.validateChatModel(updateReqVO.getModelId());
+ model = modalService.validateModel(updateReqVO.getModelId());
}
// 1.3 校验知识库是否存在
@@ -139,10 +140,11 @@ public class AiChatConversationServiceImpl implements AiChatConversationService
chatConversationMapper.deleteById(id);
}
- private void validateChatModel(AiChatModelDO model) {
+ private void validateChatModel(AiModelDO model) {
if (ObjectUtil.isAllNotEmpty(model.getTemperature(), model.getMaxTokens(), model.getMaxContexts())) {
return;
}
+ Assert.equals(model.getType(), AiModelTypeEnum.CHAT.getType(), "模型类型不正确:" + model);
throw exception(CHAT_CONVERSATION_MODEL_ERROR);
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
index d332fbf1a6..f310ba69fd 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java
@@ -10,19 +10,24 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageSendRespVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatConversationDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.chat.AiChatMessageDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
import cn.iocoder.yudao.module.ai.dal.mysql.chat.AiChatMessageMapper;
-import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeDocumentService;
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiToolService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@@ -34,18 +39,19 @@ import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
-import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
import java.time.LocalDateTime;
import java.util.*;
+import java.util.stream.Collectors;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertSet;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
@@ -58,138 +64,199 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_N
@Slf4j
public class AiChatMessageServiceImpl implements AiChatMessageService {
+ /**
+ * 知识库转 {@link UserMessage} 的内容模版
+ */
+ private static final String KNOWLEDGE_USER_MESSAGE_TEMPLATE = "使用 标记中的内容作为本次对话的参考:\n\n" +
+ "%s\n\n" + // 多个 的拼接
+ "回答要求:\n- 避免提及你是从 获取的知识。";
+
@Resource
private AiChatMessageMapper chatMessageMapper;
@Resource
private AiChatConversationService chatConversationService;
@Resource
- private AiChatModelService chatModalService;
+ private AiChatRoleService chatRoleService;
@Resource
- private AiApiKeyService apiKeyService;
+ private AiModelService modalService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
+ @Resource
+ private AiKnowledgeDocumentService knowledgeDocumentService;
+ @Resource
+ private AiToolService toolService;
@Transactional(rollbackFor = Exception.class)
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
// 1.1 校验对话存在
- AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+ AiChatConversationDO conversation = chatConversationService
+ .validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
- AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
- ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+ AiModelDO model = modalService.validateModel(conversation.getModelId());
+ ChatModel chatModel = modalService.getChatModel(model.getId());
- // 2. 插入 user 发送消息
+ // 2. 知识库找回
+ List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+ conversation);
+
+ // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
- userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+ null);
// 3.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
- userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+ knowledgeSegments);
- // 3.2 召回段落
- List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
- // 3.3 创建 chat 需要的 Prompt
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+ // 3.2 创建 chat 需要的 Prompt
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
ChatResponse chatResponse = chatModel.call(prompt);
- // 3.4 段式返回
- String newContent = chatResponse.getResult().getOutput().getContent();
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId)).setContent(newContent));
- return new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
- .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent));
+ // 3.3 更新响应内容
+ String newContent = chatResponse.getResult().getOutput().getText();
+ chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
+ // 3.4 响应结果
+ List segments = BeanUtils.toBean(knowledgeSegments,
+ AiChatMessageRespVO.KnowledgeSegment.class,
+ segment -> {
+ AiKnowledgeDocumentDO document = knowledgeDocumentService
+ .getKnowledgeDocument(segment.getDocumentId());
+ segment.setDocumentName(document != null ? document.getName() : null);
+ });
+ return new AiChatMessageSendRespVO()
+ .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+ .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+ .setContent(newContent).setSegments(segments));
}
@Override
- public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO, Long userId) {
+ public Flux> sendChatMessageStream(AiChatMessageSendReqVO sendReqVO,
+ Long userId) {
// 1.1 校验对话存在
- AiChatConversationDO conversation = chatConversationService.validateChatConversationExists(sendReqVO.getConversationId());
+ AiChatConversationDO conversation = chatConversationService
+ .validateChatConversationExists(sendReqVO.getConversationId());
if (ObjUtil.notEqual(conversation.getUserId(), userId)) {
throw exception(CHAT_CONVERSATION_NOT_EXISTS);
}
List historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
- AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
- StreamingChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+ AiModelDO model = modalService.validateModel(conversation.getModelId());
+ StreamingChatModel chatModel = modalService.getChatModel(model.getId());
- // 2. 插入 user 发送消息
+ // 2. 知识库找回
+ List knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
+ conversation);
+
+ // 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
- userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.USER, sendReqVO.getContent(), sendReqVO.getUseContext(),
+ null);
- // 3.1 插入 assistant 接收消息
+ // 4.1 插入 assistant 接收消息
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
- userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
+ userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext(),
+ knowledgeSegments);
-
- // 3.2 召回段落
- List segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
-
- // 3.3 构建 Prompt,并进行调用
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+ // 4.2 构建 Prompt,并进行调用
+ Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
Flux streamResponse = chatModel.stream(prompt);
- // 3.4 流式返回
- // TODO 注意:Schedulers.immediate() 目的是,避免默认 Schedulers.parallel() 并发消费 chunk 导致 SSE 响应前端会乱序问题
+ // 4.3 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
- String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
+ // 处理知识库的返回,只有首次才有
+ List segments = null;
+ if (StrUtil.isEmpty(contentBuffer)) {
+ segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
+ segment -> TenantUtils.executeIgnore(() -> {
+ AiKnowledgeDocumentDO document = knowledgeDocumentService
+ .getKnowledgeDocument(segment.getDocumentId());
+ segment.setDocumentName(document != null ? document.getName() : null);
+ }));
+ }
+ // 响应结果
+ String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
contentBuffer.append(newContent);
- // 响应结果
- return success(new AiChatMessageSendRespVO().setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
- .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class).setContent(newContent)));
+ return success(new AiChatMessageSendRespVO()
+ .setSend(BeanUtils.toBean(userMessage, AiChatMessageSendRespVO.Message.class))
+ .setReceive(BeanUtils.toBean(assistantMessage, AiChatMessageSendRespVO.Message.class)
+ .setContent(newContent).setSegments(segments)));
}).doOnComplete(() -> {
// 忽略租户,因为 Flux 异步无法透传租户
- TenantUtils.executeIgnore(() ->
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setSegmentIds(convertList(segmentList, AiKnowledgeSegmentDO::getId))
- .setContent(contentBuffer.toString())));
+ TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+ new AiChatMessageDO().setId(assistantMessage.getId()).setContent(contentBuffer.toString())));
}).doOnError(throwable -> {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
// 忽略租户,因为 Flux 异步无法透传租户
- TenantUtils.executeIgnore(() ->
- chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
+ TenantUtils.executeIgnore(() -> chatMessageMapper.updateById(
+ new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage())));
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
}
- private List recallSegment(String content, Long knowledgeId) {
- if (Objects.isNull(knowledgeId)) {
+ private List recallKnowledgeSegment(String content,
+ AiChatConversationDO conversation) {
+ // 1. 查询聊天角色
+ if (conversation == null || conversation.getRoleId() == null) {
return Collections.emptyList();
}
- return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
- }
-
- private Prompt buildPrompt(AiChatConversationDO conversation, List messages,List segmentList,
- AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
- // 1. 构建 Prompt Message 列表
- List chatMessages = new ArrayList<>();
-
- // 1.1 召回内容消息构建
- if (CollUtil.isNotEmpty(segmentList)) {
- PromptTemplate promptTemplate = new PromptTemplate(AiChatRoleEnum.AI_KNOWLEDGE_ROLE.getSystemMessage());
- StringBuilder infoBuilder = StrUtil.builder();
- segmentList.forEach(segment -> infoBuilder.append(System.lineSeparator()).append(segment.getContent()));
- Message message = promptTemplate.createMessage(Map.of("info", infoBuilder.toString()));
- chatMessages.add(message);
+ AiChatRoleDO role = chatRoleService.getChatRole(conversation.getRoleId());
+ if (role == null || CollUtil.isEmpty(role.getKnowledgeIds())) {
+ return Collections.emptyList();
}
- // 1.2 system context 角色设定
+ // 2. 遍历找回
+ List knowledgeSegments = new ArrayList<>();
+ for (Long knowledgeId : role.getKnowledgeIds()) {
+ knowledgeSegments.addAll(knowledgeSegmentService.searchKnowledgeSegment(new AiKnowledgeSegmentSearchReqBO()
+ .setKnowledgeId(knowledgeId).setContent(content)));
+ }
+ return knowledgeSegments;
+ }
+
+ private Prompt buildPrompt(AiChatConversationDO conversation, List messages,
+ List knowledgeSegments,
+ AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+ List chatMessages = new ArrayList<>();
+ // 1.1 System Context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
chatMessages.add(new SystemMessage(conversation.getSystemMessage()));
}
- // 1.3 history message 历史消息
+
+ // 1.2 历史 history message 历史消息
List contextMessages = filterContextMessages(messages, conversation, sendReqVO);
- contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
- // 1.4 user message 新发送消息
+ contextMessages
+ .forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
+
+ // 1.3 当前 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
- // 2. 构建 ChatOptions 对象
+ // 1.4 知识库,通过 UserMessage 实现
+ if (CollUtil.isNotEmpty(knowledgeSegments)) {
+ String reference = knowledgeSegments.stream()
+ .map(segment -> "" + segment.getContent() + "")
+ .collect(Collectors.joining("\n\n"));
+ chatMessages.add(new UserMessage(String.format(KNOWLEDGE_USER_MESSAGE_TEMPLATE, reference)));
+ }
+
+ // 2.1 查询 tool 工具
+ Set toolNames = null;
+ if (conversation.getRoleId() != null) {
+ AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
+ if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
+ toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
+ }
+ }
+ // 2.2 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
- conversation.getTemperature(), conversation.getMaxTokens());
+ conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
return new Prompt(chatMessages, chatOptions);
}
@@ -204,8 +271,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
* @return 消息上下文
*/
private List filterContextMessages(List messages,
- AiChatConversationDO conversation,
- AiChatMessageSendReqVO sendReqVO) {
+ AiChatConversationDO conversation,
+ AiChatMessageSendReqVO sendReqVO) {
if (conversation.getMaxContexts() == null || ObjUtil.notEqual(sendReqVO.getUseContext(), Boolean.TRUE)) {
return Collections.emptyList();
}
@@ -216,7 +283,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
continue;
}
AiChatMessageDO userMessage = CollUtil.get(messages, i - 1);
- if (userMessage == null || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
+ if (userMessage == null
+ || ObjUtil.notEqual(assistantMessage.getReplyId(), userMessage.getId())
|| StrUtil.isEmpty(assistantMessage.getContent())) {
continue;
}
@@ -233,11 +301,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
private AiChatMessageDO createChatMessage(Long conversationId, Long replyId,
- AiChatModelDO model, Long userId, Long roleId,
- MessageType messageType, String content, Boolean useContext) {
+ AiModelDO model, Long userId, Long roleId,
+ MessageType messageType, String content, Boolean useContext,
+ List knowledgeSegments) {
AiChatMessageDO message = new AiChatMessageDO().setConversationId(conversationId).setReplyId(replyId)
.setModel(model.getModel()).setModelId(model.getId()).setUserId(userId).setRoleId(roleId)
- .setType(messageType.getValue()).setContent(content).setUseContext(useContext);
+ .setType(messageType.getValue()).setContent(content).setUseContext(useContext)
+ .setSegmentIds(convertList(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getId));
message.setCreateTime(LocalDateTime.now());
chatMessageMapper.insert(message);
return message;
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
index e8532a5762..60ca9ac996 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/image/AiImageServiceImpl.java
@@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.service.image;
import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.codec.Base64;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
@@ -12,15 +13,19 @@ import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.image.vo.*;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePublicPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.image.AiImageMapper;
import cn.iocoder.yudao.module.ai.enums.image.AiImageStatusEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
-import com.alibaba.cloud.ai.tongyi.image.TongYiImagesOptions;
+import com.alibaba.cloud.ai.dashscope.image.DashScopeImageOptions;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageModel;
@@ -54,15 +59,15 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
@Slf4j
public class AiImageServiceImpl implements AiImageService {
+ @Resource
+ private AiModelService modelService;
+
@Resource
private AiImageMapper imageMapper;
@Resource
private FileApi fileApi;
- @Resource
- private AiApiKeyService apiKeyService;
-
@Override
public PageResult getImagePageMy(Long userId, AiImagePageReqVO pageReqVO) {
return imageMapper.selectPageMy(userId, pageReqVO);
@@ -88,23 +93,31 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
- // 1. 保存数据库
- AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
- .setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
+ // 1. 校验模型
+ AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
+
+ // 2. 保存数据库
+ AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId)
+ .setPlatform(model.getPlatform()).setModelId(model.getId()).setModel(model.getModel())
+ .setPublicStatus(false).setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus());
imageMapper.insert(image);
- // 2. 异步绘制,后续前端通过返回的 id 进行轮询结果
- getSelf().executeDrawImage(image, drawReqVO);
+
+ // 3. 异步绘制,后续前端通过返回的 id 进行轮询结果
+ getSelf().executeDrawImage(image, drawReqVO, model);
return image.getId();
}
@Async
- public void executeDrawImage(AiImageDO image, AiImageDrawReqVO req) {
+ public void executeDrawImage(AiImageDO image, AiImageDrawReqVO reqVO, AiModelDO model) {
try {
// 1.1 构建请求
- ImageOptions request = buildImageOptions(req);
+ ImageOptions request = buildImageOptions(reqVO, model);
// 1.2 执行请求
- ImageModel imageModel = apiKeyService.getImageModel(AiPlatformEnum.validatePlatform(req.getPlatform()));
- ImageResponse response = imageModel.call(new ImagePrompt(req.getPrompt(), request));
+ ImageModel imageModel = modelService.getImageModel(model.getId());
+ ImageResponse response = imageModel.call(new ImagePrompt(reqVO.getPrompt(), request));
+ if (response.getResult() == null) {
+ throw new IllegalArgumentException("生成结果为空");
+ }
// 2. 上传到文件服务
String b64Json = response.getResult().getOutput().getB64Json();
@@ -116,49 +129,49 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.updateById(new AiImageDO().setId(image.getId()).setStatus(AiImageStatusEnum.SUCCESS.getStatus())
.setPicUrl(filePath).setFinishTime(LocalDateTime.now()));
} catch (Exception ex) {
- log.error("[doDall][image({}) 生成异常]", image, ex);
+ log.error("[executeDrawImage][image({}) 生成异常]", image, ex);
imageMapper.updateById(new AiImageDO().setId(image.getId())
.setStatus(AiImageStatusEnum.FAIL.getStatus())
.setErrorMessage(ex.getMessage()).setFinishTime(LocalDateTime.now()));
}
}
- private static ImageOptions buildImageOptions(AiImageDrawReqVO draw) {
- if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
+ private static ImageOptions buildImageOptions(AiImageDrawReqVO draw, AiModelDO model) {
+ if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.OPENAI.getPlatform())) {
// https://platform.openai.com/docs/api-reference/images/create
- return OpenAiImageOptions.builder().withModel(draw.getModel())
+ return OpenAiImageOptions.builder().withModel(model.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withStyle(MapUtil.getStr(draw.getOptions(), "style")) // 风格
.withResponseFormat("b64_json")
.build();
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
- return StabilityAiImageOptions.builder().withModel(draw.getModel())
- .withHeight(draw.getHeight()).withWidth(draw.getWidth())
- .withSeed(Long.valueOf(draw.getOptions().get("seed")))
- .withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
- .withSteps(Integer.valueOf(draw.getOptions().get("steps")))
- .withSampler(String.valueOf(draw.getOptions().get("sampler")))
- .withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
- .withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
+ return StabilityAiImageOptions.builder().model(model.getModel())
+ .height(draw.getHeight()).width(draw.getWidth())
+ .seed(Long.valueOf(draw.getOptions().get("seed")))
+ .cfgScale(Float.valueOf(draw.getOptions().get("scale")))
+ .steps(Integer.valueOf(draw.getOptions().get("steps")))
+ .sampler(String.valueOf(draw.getOptions().get("sampler")))
+ .stylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
+ .clipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
.build();
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
- return TongYiImagesOptions.builder()
- .withModel(draw.getModel()).withN(1)
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.TONG_YI.getPlatform())) {
+ return DashScopeImageOptions.builder()
+ .withModel(model.getModel()).withN(1)
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.build();
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.YI_YAN.getPlatform())) {
return QianFanImageOptions.builder()
- .withModel(draw.getModel()).withN(1)
- .withHeight(draw.getHeight()).withWidth(draw.getWidth())
+ .model(model.getModel()).N(1)
+ .height(draw.getHeight()).width(draw.getWidth())
.build();
- } else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
+ } else if (ObjUtil.equal(model.getPlatform(), AiPlatformEnum.ZHI_PU.getPlatform())) {
return ZhiPuAiImageOptions.builder()
- .withModel(draw.getModel())
+ .model(model.getModel())
.build();
}
- throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
+ throw new IllegalArgumentException("不支持的 AI 平台:" + model.getPlatform());
}
@Override
@@ -205,52 +218,56 @@ public class AiImageServiceImpl implements AiImageService {
@Override
@Transactional(rollbackFor = Exception.class)
- public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
- // 1. 保存数据库
- AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
+ public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO drawReqVO) {
+ // 1. 校验模型
+ AiModelDO model = modelService.validateModel(drawReqVO.getModelId());
+ Assert.equals(model.getPlatform(), AiPlatformEnum.MIDJOURNEY.getPlatform(), "平台不匹配");
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(model.getId());
+
+ // 2. 保存数据库
+ AiImageDO image = BeanUtils.toBean(drawReqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
- .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform());
+ .setPlatform(AiPlatformEnum.MIDJOURNEY.getPlatform()).setModelId(model.getId()).setModel(model.getName());
imageMapper.insert(image);
- // 2. 调用 Midjourney Proxy 提交任务
- List base64Array = StrUtil.isBlank(reqVO.getReferImageUrl()) ? null :
- Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl()))));
+ // 3. 调用 Midjourney Proxy 提交任务
+ List base64Array = StrUtil.isBlank(drawReqVO.getReferImageUrl()) ? null :
+ Collections.singletonList("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(drawReqVO.getReferImageUrl()))));
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
- base64Array, reqVO.getPrompt(),null,
- MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(),
- reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
+ base64Array, drawReqVO.getPrompt(),null,
+ MidjourneyApi.ImagineRequest.buildState(drawReqVO.getWidth(),
+ drawReqVO.getHeight(), drawReqVO.getVersion(), model.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
- // 3. 情况一【失败】:抛出业务异常
+ // 4.1 情况一【失败】:抛出业务异常
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
String description = imagineResponse.description().contains("quota_not_enough") ?
"账户余额不足" : imagineResponse.description();
throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
- // 4. 情况二【成功】:更新 taskId 和参数
+ // 4.2 情况二【成功】:更新 taskId 和参数
imageMapper.updateById(new AiImageDO().setId(image.getId())
- .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(reqVO)));
+ .setTaskId(imagineResponse.result()).setOptions(BeanUtil.beanToMap(drawReqVO)));
return image.getId();
}
@Override
public Integer midjourneySync() {
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台,状态在 “进行中” 的 image
- List imageList = imageMapper.selectListByStatusAndPlatform(
+ List images = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
- if (CollUtil.isEmpty(imageList)) {
+ if (CollUtil.isEmpty(images)) {
return 0;
}
// 1.2 调用 Midjourney Proxy 获取任务进展
- List taskList = midjourneyApi.getTaskList(convertSet(imageList, AiImageDO::getTaskId));
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(images.get(0).getModelId());
+ List taskList = midjourneyApi.getTaskList(convertSet(images, AiImageDO::getTaskId));
Map taskMap = convertMap(taskList, MidjourneyApi.Notify::id);
// 2. 逐个处理,更新进展
int count = 0;
- for (AiImageDO image : imageList) {
+ for (AiImageDO image : images) {
MidjourneyApi.Notify notify = taskMap.get(image.getTaskId());
if (notify == null) {
log.error("[midjourneySync][image({}) 查询不到进展]", image);
@@ -308,12 +325,12 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
- MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) {
throw exception(IMAGE_NOT_EXISTS);
}
+ MidjourneyApi midjourneyApi = modelService.getMidjourneyApi(image.getModelId());
// 1.2 检查 customId
MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
buttonX -> buttonX.customId().equals(reqVO.getCustomId()));
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentService.java
index 3de0ac01de..8ff137b331 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentService.java
@@ -1,26 +1,41 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
+
/**
- * AI 知识库-文档 Service 接口
+ * AI 知识库文档 Service 接口
*
* @author xiaoxin
*/
public interface AiKnowledgeDocumentService {
/**
- * 创建文档
+ * 创建文档(单个)
*
* @param createReqVO 文档创建 Request VO
* @return 文档编号
*/
Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO);
+ /**
+ * 创建文档(多个)
+ *
+ * @param createListReqVO 批量创建 Request VO
+ * @return 文档编号列表
+ */
+ List createKnowledgeDocumentList(AiKnowledgeDocumentCreateListReqVO createListReqVO);
/**
* 获取文档分页
@@ -30,10 +45,74 @@ public interface AiKnowledgeDocumentService {
*/
PageResult getKnowledgeDocumentPage(AiKnowledgeDocumentPageReqVO pageReqVO);
+ /**
+ * 获取文档详情
+ *
+ * @param id 文档编号
+ * @return 文档详情
+ */
+ AiKnowledgeDocumentDO getKnowledgeDocument(Long id);
+
/**
* 更新文档
*
* @param reqVO 更新信息
*/
void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO);
+
+ /**
+ * 更新文档状态
+ *
+ * @param reqVO 更新状态信息
+ */
+ void updateKnowledgeDocumentStatus(AiKnowledgeDocumentUpdateStatusReqVO reqVO);
+
+ /**
+ * 更新文档检索次数(增加 +1)
+ *
+ * @param ids 文档编号列表
+ */
+ void updateKnowledgeDocumentRetrievalCountIncr(Collection ids);
+
+ /**
+ * 删除文档
+ *
+ * @param id 文档编号
+ */
+ void deleteKnowledgeDocument(Long id);
+
+ /**
+ * 校验文档是否存在
+ *
+ * @param id 文档编号
+ * @return 文档信息
+ */
+ AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id);
+
+ /**
+ * 读取 URL 内容
+ *
+ * @param url URL
+ * @return 内容
+ */
+ String readUrl(String url);
+
+ /**
+ * 获取文档列表
+ *
+ * @param ids 文档编号列表
+ * @return 文档列表
+ */
+ List getKnowledgeDocumentList(Collection ids);
+
+ /**
+ * 获取文档 Map
+ *
+ * @param ids 文档编号列表
+ * @return 文档 Map
+ */
+ default Map getKnowledgeDocumentMap(Collection ids) {
+ return convertMap(getKnowledgeDocumentList(ids), AiKnowledgeDocumentDO::getId);
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
index ff475f92ca..2d78f94f34 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeDocumentServiceImpl.java
@@ -1,34 +1,37 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
+import cn.hutool.core.util.ObjUtil;
+import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.common.util.collection.CollectionUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentCreateListReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.document.AiKnowledgeDocumentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeDocumentCreateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeDocumentMapper;
-import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
-import cn.iocoder.yudao.module.ai.enums.knowledge.AiKnowledgeDocumentStatusEnum;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.tika.TikaDocumentReader;
import org.springframework.ai.tokenizer.TokenCountEstimator;
-import org.springframework.ai.transformer.splitter.TokenTextSplitter;
-import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.context.annotation.Lazy;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCUMENT_NOT_EXISTS;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 知识库文档 Service 实现类
@@ -40,91 +43,172 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_DOCU
public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentService {
@Resource
- private AiKnowledgeDocumentMapper documentMapper;
- @Resource
- private AiKnowledgeSegmentMapper segmentMapper;
+ private AiKnowledgeDocumentMapper knowledgeDocumentMapper;
@Resource
private TokenCountEstimator tokenCountEstimator;
+
@Resource
+ private AiKnowledgeSegmentService knowledgeSegmentService;
+ @Resource
+ @Lazy // 延迟加载,避免循环依赖
private AiKnowledgeService knowledgeService;
@Override
- @Transactional(rollbackFor = Exception.class)
public Long createKnowledgeDocument(AiKnowledgeDocumentCreateReqVO createReqVO) {
- // 0. 校验并获取向量存储实例
- VectorStore vectorStore = knowledgeService.getVectorStoreById(createReqVO.getKnowledgeId());
+ // 1. 校验参数
+ knowledgeService.validateKnowledgeExists(createReqVO.getKnowledgeId());
- // 1.1 下载文档
- TikaDocumentReader loader = new TikaDocumentReader(downloadFile(createReqVO.getUrl()));
- List documents = loader.get();
- Document document = CollUtil.getFirst(documents);
- // 1.2 文档记录入库
- String content = document.getContent();
+ // 2. 下载文档
+ String content = readUrl(createReqVO.getUrl());
+
+ // 3. 文档记录入库
AiKnowledgeDocumentDO documentDO = BeanUtils.toBean(createReqVO, AiKnowledgeDocumentDO.class)
- .setTokens(tokenCountEstimator.estimate(content)).setWordCount(content.length())
- .setStatus(CommonStatusEnum.ENABLE.getStatus()).setSliceStatus(AiKnowledgeDocumentStatusEnum.SUCCESS.getStatus());
- documentMapper.insert(documentDO);
- Long documentId = documentDO.getId();
- if (CollUtil.isEmpty(documents)) {
- return documentId;
+ .setContent(content).setContentLength(content.length()).setTokens(tokenCountEstimator.estimate(content))
+ .setStatus(CommonStatusEnum.ENABLE.getStatus());
+ knowledgeDocumentMapper.insert(documentDO);
+
+ // 4. 文档切片入库(异步)
+ knowledgeSegmentService.createKnowledgeSegmentBySplitContentAsync(documentDO.getId(), content);
+ return documentDO.getId();
+ }
+
+ @Override
+ public List createKnowledgeDocumentList(AiKnowledgeDocumentCreateListReqVO createListReqVO) {
+ // 1. 校验参数
+ knowledgeService.validateKnowledgeExists(createListReqVO.getKnowledgeId());
+
+ // 2. 下载文档
+ List contents = convertList(createListReqVO.getList(), document -> readUrl(document.getUrl()));
+
+ // 3. 文档记录入库
+ List documentDOs = new ArrayList<>(createListReqVO.getList().size());
+ for (int i = 0; i < createListReqVO.getList().size(); i++) {
+ AiKnowledgeDocumentCreateListReqVO.Document documentVO = createListReqVO.getList().get(i);
+ String content = contents.get(i);
+ documentDOs.add(BeanUtils.toBean(documentVO, AiKnowledgeDocumentDO.class)
+ .setKnowledgeId(createListReqVO.getKnowledgeId())
+ .setContent(content).setContentLength(content.length())
+ .setTokens(tokenCountEstimator.estimate(content))
+ .setSegmentMaxTokens(createListReqVO.getSegmentMaxTokens())
+ .setStatus(CommonStatusEnum.ENABLE.getStatus()));
}
+ knowledgeDocumentMapper.insertBatch(documentDOs);
- // 2 构造文本分段器
- TokenTextSplitter tokenTextSplitter = new TokenTextSplitter(createReqVO.getDefaultSegmentTokens(), createReqVO.getMinSegmentWordCount(), createReqVO.getMinChunkLengthToEmbed(),
- createReqVO.getMaxNumSegments(), createReqVO.getKeepSeparator());
- // 2.1 文档分段
- List segments = tokenTextSplitter.apply(documents);
- // 2.2 分段内容入库
- List segmentDOList = CollectionUtils.convertList(segments,
- segment -> new AiKnowledgeSegmentDO().setContent(segment.getContent()).setDocumentId(documentId)
- .setKnowledgeId(createReqVO.getKnowledgeId()).setVectorId(segment.getId())
- .setTokens(tokenCountEstimator.estimate(segment.getContent())).setWordCount(segment.getContent().length())
- .setStatus(CommonStatusEnum.ENABLE.getStatus()));
- segmentMapper.insertBatch(segmentDOList);
-
- // 3. 向量化并存储
- segments.forEach(segment -> segment.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, createReqVO.getKnowledgeId()));
- vectorStore.add(segments);
- return documentId;
+ // 4. 批量创建文档切片(异步)
+ documentDOs.forEach(documentDO -> knowledgeSegmentService
+ .createKnowledgeSegmentBySplitContentAsync(documentDO.getId(), documentDO.getContent()));
+ return convertList(documentDOs, AiKnowledgeDocumentDO::getId);
}
@Override
public PageResult getKnowledgeDocumentPage(AiKnowledgeDocumentPageReqVO pageReqVO) {
- return documentMapper.selectPage(pageReqVO);
+ return knowledgeDocumentMapper.selectPage(pageReqVO);
+ }
+
+ @Override
+ public AiKnowledgeDocumentDO getKnowledgeDocument(Long id) {
+ return knowledgeDocumentMapper.selectById(id);
}
@Override
public void updateKnowledgeDocument(AiKnowledgeDocumentUpdateReqVO reqVO) {
// 1. 校验文档是否存在
- validateKnowledgeDocumentExists(reqVO.getId());
+ AiKnowledgeDocumentDO oldDocument = validateKnowledgeDocumentExists(reqVO.getId());
+
// 2. 更新文档
AiKnowledgeDocumentDO document = BeanUtils.toBean(reqVO, AiKnowledgeDocumentDO.class);
- documentMapper.updateById(document);
+ knowledgeDocumentMapper.updateById(document);
+
+ // 3. 如果处于开启状态,并且最大 tokens 发生变化,则 segment 需要重新索引
+ if (CommonStatusEnum.isEnable(oldDocument.getStatus())
+ && reqVO.getSegmentMaxTokens() != null
+ && ObjUtil.notEqual(reqVO.getSegmentMaxTokens(), oldDocument.getSegmentMaxTokens())) {
+ // 删除旧的文档切片
+ knowledgeSegmentService.deleteKnowledgeSegmentByDocumentId(reqVO.getId());
+ // 重新创建文档切片
+ knowledgeSegmentService.createKnowledgeSegmentBySplitContentAsync(reqVO.getId(), oldDocument.getContent());
+ }
}
- /**
- * 校验文档是否存在
- *
- * @param id 文档编号
- * @return 文档信息
- */
- private AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
- AiKnowledgeDocumentDO knowledgeDocument = documentMapper.selectById(id);
+ @Override
+ public void updateKnowledgeDocumentStatus(AiKnowledgeDocumentUpdateStatusReqVO reqVO) {
+ // 1. 校验存在
+ AiKnowledgeDocumentDO document = validateKnowledgeDocumentExists(reqVO.getId());
+
+ // 2. 更新状态
+ knowledgeDocumentMapper.updateById(new AiKnowledgeDocumentDO()
+ .setId(reqVO.getId()).setStatus(reqVO.getStatus()));
+
+ // 3. 处理文档切片
+ if (CommonStatusEnum.isEnable(reqVO.getStatus())) {
+ knowledgeSegmentService.createKnowledgeSegmentBySplitContentAsync(reqVO.getId(), document.getContent());
+ } else {
+ knowledgeSegmentService.deleteKnowledgeSegmentByDocumentId(reqVO.getId());
+ }
+ }
+
+ @Override
+ @Transactional(rollbackFor = Exception.class)
+ public void deleteKnowledgeDocument(Long id) {
+ // 1. 校验存在
+ validateKnowledgeDocumentExists(id);
+
+ // 2. 删除
+ knowledgeDocumentMapper.deleteById(id);
+
+ // 3. 删除对应的段落
+ knowledgeSegmentService.deleteKnowledgeSegmentByDocumentId(id);
+ }
+
+ @Override
+ public void updateKnowledgeDocumentRetrievalCountIncr(Collection ids) {
+ if (CollUtil.isEmpty(ids)) {
+ return;
+ }
+ knowledgeDocumentMapper.updateRetrievalCountIncr(ids);
+ }
+
+ @Override
+ public AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
+ AiKnowledgeDocumentDO knowledgeDocument = knowledgeDocumentMapper.selectById(id);
if (knowledgeDocument == null) {
throw exception(KNOWLEDGE_DOCUMENT_NOT_EXISTS);
}
return knowledgeDocument;
}
- private org.springframework.core.io.Resource downloadFile(String url) {
+ @Override
+ public String readUrl(String url) {
+ // 下载文件
+ ByteArrayResource resource;
try {
byte[] bytes = HttpUtil.downloadBytes(url);
- return new ByteArrayResource(bytes);
+ if (bytes.length == 0) {
+ throw exception(KNOWLEDGE_DOCUMENT_FILE_EMPTY);
+ }
+ resource = new ByteArrayResource(bytes);
} catch (Exception e) {
- log.error("[downloadFile][url({}) 下载失败]", url, e);
- throw new RuntimeException(e);
+ log.error("[readUrl][url({}) 读取失败]", url, e);
+ throw exception(KNOWLEDGE_DOCUMENT_FILE_DOWNLOAD_FAIL);
}
+
+ // 读取文件
+ TikaDocumentReader loader = new TikaDocumentReader(resource);
+ List documents = loader.get();
+ Document document = CollUtil.getFirst(documents);
+ if (document == null || StrUtil.isEmpty(document.getText())) {
+ throw exception(KNOWLEDGE_DOCUMENT_FILE_READ_FAIL);
+ }
+ return document.getText();
+ }
+
+ @Override
+ public List getKnowledgeDocumentList(Collection ids) {
+ if (CollUtil.isEmpty(ids)) {
+ return Collections.emptyList();
+ }
+ return knowledgeDocumentMapper.selectBatchIds(ids);
}
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
index 91bffc2761..54f7217055 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentService.java
@@ -2,12 +2,19 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import org.springframework.scheduling.annotation.Async;
+import java.util.Collection;
import java.util.List;
+import java.util.Map;
+
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
/**
* AI 知识库段落 Service 接口
@@ -16,6 +23,32 @@ import java.util.List;
*/
public interface AiKnowledgeSegmentService {
+ /**
+ * 获取知识库段落详情
+ *
+ * @param id 段落编号
+ * @return 段落详情
+ */
+ AiKnowledgeSegmentDO getKnowledgeSegment(Long id);
+
+ /**
+ * 获取知识库段落列表
+ *
+ * @param ids 段落编号列表
+ * @return 段落列表
+ */
+ List getKnowledgeSegmentList(Collection ids);
+
+ /**
+ * 获取知识库段落 Map
+ *
+ * @param ids 段落编号列表
+ * @return 段落 Map
+ */
+ default Map getKnowledgeSegmentMap(Collection ids) {
+ return convertMap(getKnowledgeSegmentList(ids), AiKnowledgeSegmentDO::getId);
+ }
+
/**
* 获取段落分页
*
@@ -24,12 +57,39 @@ public interface AiKnowledgeSegmentService {
*/
PageResult getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO);
+ /**
+ * 基于 content 内容,切片创建多个段落
+ *
+ * @param documentId 知识库文档编号
+ * @param content 文档内容
+ */
+ void createKnowledgeSegmentBySplitContent(Long documentId, String content);
+
+ /**
+ * 【异步】基于 content 内容,切片创建多个段落
+ *
+ * @param documentId 知识库文档编号
+ * @param content 文档内容
+ */
+ @Async
+ default void createKnowledgeSegmentBySplitContentAsync(Long documentId, String content) {
+ createKnowledgeSegmentBySplitContent(documentId, content);
+ }
+
+ /**
+ * 创建知识库段落
+ *
+ * @param createReqVO 创建信息
+ * @return 段落编号
+ */
+ Long createKnowledgeSegment(AiKnowledgeSegmentSaveReqVO createReqVO);
+
/**
* 更新段落的内容
*
* @param reqVO 更新内容
*/
- void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO);
+ void updateKnowledgeSegment(AiKnowledgeSegmentSaveReqVO reqVO);
/**
* 更新段落的状态
@@ -39,11 +99,52 @@ public interface AiKnowledgeSegmentService {
void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO);
/**
- * 召回段落
+ * 重新索引知识库下的所有文档段落
*
- * @param reqVO 召回请求信息
- * @return 召回的段落
+ * @param knowledgeId 知识库编号
*/
- List similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO);
+ void reindexKnowledgeSegmentByKnowledgeId(Long knowledgeId);
+
+ /**
+ * 【异步】重新索引知识库下的所有文档段落
+ *
+ * @param knowledgeId 知识库编号
+ */
+ @Async
+ default void reindexByKnowledgeIdAsync(Long knowledgeId) {
+ reindexKnowledgeSegmentByKnowledgeId(knowledgeId);
+ }
+
+ /**
+ * 根据文档编号删除段落
+ *
+ * @param documentId 文档编号
+ */
+ void deleteKnowledgeSegmentByDocumentId(Long documentId);
+
+ /**
+ * 搜索知识库段落,并返回结果
+ *
+ * @param reqBO 搜索请求信息
+ * @return 搜索结果段落列表
+ */
+ List searchKnowledgeSegment(AiKnowledgeSegmentSearchReqBO reqBO);
+
+ /**
+ * 根据 URL 内容,切片创建多个段落
+ *
+ * @param url URL 地址
+ * @param segmentMaxTokens 段落最大 Token 数
+ * @return 切片后的段落列表
+ */
+ List splitContent(String url, Integer segmentMaxTokens);
+
+ /**
+ * 获取文档处理进度(多个)
+ *
+ * @param documentIds 文档编号列表
+ * @return 文档处理列表
+ */
+ List getKnowledgeSegmentProcessList(List documentIds);
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
index 5523fe2783..20f881cf13 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeSegmentServiceImpl.java
@@ -2,31 +2,39 @@ package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
+import cn.hutool.core.util.ObjUtil;
+import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSearchReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentProcessRespVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentSaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.segment.AiKnowledgeSegmentUpdateStatusReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeSegmentDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeSegmentMapper;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchReqBO;
+import cn.iocoder.yudao.module.ai.service.knowledge.bo.AiKnowledgeSegmentSearchRespBO;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
+import org.springframework.ai.tokenizer.TokenCountEstimator;
+import org.springframework.ai.transformer.splitter.TextSplitter;
+import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
+import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
-import java.util.List;
-import java.util.Objects;
+import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGMENT_NOT_EXISTS;
/**
@@ -38,15 +46,28 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_SEGM
@Slf4j
public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService {
+ private static final String VECTOR_STORE_METADATA_KNOWLEDGE_ID = "knowledgeId";
+ private static final String VECTOR_STORE_METADATA_DOCUMENT_ID = "documentId";
+ private static final String VECTOR_STORE_METADATA_SEGMENT_ID = "segmentId";
+
+ private static final Map> VECTOR_STORE_METADATA_TYPES = Map.of(
+ VECTOR_STORE_METADATA_KNOWLEDGE_ID, String.class,
+ VECTOR_STORE_METADATA_DOCUMENT_ID, String.class,
+ VECTOR_STORE_METADATA_SEGMENT_ID, String.class);
+
@Resource
private AiKnowledgeSegmentMapper segmentMapper;
@Resource
private AiKnowledgeService knowledgeService;
@Resource
- private AiChatModelService chatModelService;
+ @Lazy // 延迟加载,避免循环依赖
+ private AiKnowledgeDocumentService knowledgeDocumentService;
@Resource
- private AiApiKeyService apiKeyService;
+ private AiModelService modelService;
+
+ @Resource
+ private TokenCountEstimator tokenCountEstimator;
@Override
public PageResult getKnowledgeSegmentPage(AiKnowledgeSegmentPageReqVO pageReqVO) {
@@ -54,67 +75,198 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
}
@Override
- public void updateKnowledgeSegment(AiKnowledgeSegmentUpdateReqVO reqVO) {
+ public void createKnowledgeSegmentBySplitContent(Long documentId, String content) {
// 1. 校验
- AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
+ AiKnowledgeDocumentDO documentDO = knowledgeDocumentService.validateKnowledgeDocumentExists(documentId);
+ AiKnowledgeDO knowledgeDO = knowledgeService.validateKnowledgeExists(documentDO.getKnowledgeId());
+ VectorStore vectorStore = getVectorStoreById(knowledgeDO);
- // 2.1 获取知识库向量实例
- VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
- // 2.2 删除原向量
- vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
- // 2.3 重新向量化
- Document document = new Document(reqVO.getContent());
- document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
- vectorStore.add(List.of(document));
+ // 2. 文档切片
+ List documentSegments = splitContentByToken(content, documentDO.getSegmentMaxTokens());
- // 3. 更新段落内容
- AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
- knowledgeSegment.setVectorId(document.getId());
- segmentMapper.updateById(knowledgeSegment);
+ // 3.1 存储切片
+ List segmentDOs = convertList(documentSegments, segment -> {
+ if (StrUtil.isEmpty(segment.getText())) {
+ return null;
+ }
+ return new AiKnowledgeSegmentDO().setKnowledgeId(documentDO.getKnowledgeId()).setDocumentId(documentId)
+ .setContent(segment.getText()).setContentLength(segment.getText().length())
+ .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY)
+ .setTokens(tokenCountEstimator.estimate(segment.getText()))
+ .setStatus(CommonStatusEnum.ENABLE.getStatus());
+ });
+ segmentMapper.insertBatch(segmentDOs);
+ // 3.2 切片向量化
+ for (int i = 0; i < documentSegments.size(); i++) {
+ Document segment = documentSegments.get(i);
+ AiKnowledgeSegmentDO segmentDO = segmentDOs.get(i);
+ writeVectorStore(vectorStore, segmentDO, segment);
+ }
+ }
+
+ @Override
+ public void updateKnowledgeSegment(AiKnowledgeSegmentSaveReqVO reqVO) {
+ // 1. 校验
+ AiKnowledgeSegmentDO oldSegment = validateKnowledgeSegmentExists(reqVO.getId());
+
+ // 2. 删除向量
+ VectorStore vectorStore = getVectorStoreById(oldSegment.getKnowledgeId());
+ deleteVectorStore(vectorStore, oldSegment);
+
+ // 3.1 更新切片
+ AiKnowledgeSegmentDO newSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
+ segmentMapper.updateById(newSegment);
+ // 3.2 重新向量化,必须开启状态
+ if (CommonStatusEnum.isEnable(oldSegment.getStatus())) {
+ newSegment.setKnowledgeId(oldSegment.getKnowledgeId()).setDocumentId(oldSegment.getDocumentId());
+ writeVectorStore(vectorStore, newSegment, new Document(newSegment.getContent()));
+ }
+ }
+
+ @Override
+ public void deleteKnowledgeSegmentByDocumentId(Long documentId) {
+ // 1. 查询需要删除的段落
+ List segments = segmentMapper.selectListByDocumentId(documentId);
+ if (CollUtil.isEmpty(segments)) {
+ return;
+ }
+
+ // 2. 批量删除段落记录
+ segmentMapper.deleteByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
+
+ // 3. 删除向量存储中的段落
+ VectorStore vectorStore = getVectorStoreById(segments.get(0).getKnowledgeId());
+ vectorStore.delete(convertList(segments, AiKnowledgeSegmentDO::getVectorId));
}
@Override
public void updateKnowledgeSegmentStatus(AiKnowledgeSegmentUpdateStatusReqVO reqVO) {
- // 0 校验
- AiKnowledgeSegmentDO oldKnowledgeSegment = validateKnowledgeSegmentExists(reqVO.getId());
- // 1 获取知识库向量实例
- VectorStore vectorStore = knowledgeService.getVectorStoreById(oldKnowledgeSegment.getKnowledgeId());
- AiKnowledgeSegmentDO knowledgeSegment = BeanUtils.toBean(reqVO, AiKnowledgeSegmentDO.class);
+ // 1. 校验
+ AiKnowledgeSegmentDO segment = validateKnowledgeSegmentExists(reqVO.getId());
- if (Objects.equals(reqVO.getStatus(), CommonStatusEnum.ENABLE.getStatus())) {
- // 2.1 启用重新向量化
- Document document = new Document(oldKnowledgeSegment.getContent());
- document.getMetadata().put(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, oldKnowledgeSegment.getKnowledgeId());
- vectorStore.add(List.of(document));
- knowledgeSegment.setVectorId(document.getId());
+ // 2. 获取知识库向量实例
+ VectorStore vectorStore = getVectorStoreById(segment.getKnowledgeId());
+
+ // 3. 更新状态
+ segmentMapper.updateById(new AiKnowledgeSegmentDO().setId(reqVO.getId()).setStatus(reqVO.getStatus()));
+
+ // 4. 更新向量
+ if (CommonStatusEnum.isEnable(reqVO.getStatus())) {
+ writeVectorStore(vectorStore, segment, new Document(segment.getContent()));
} else {
- // 2.2 禁用删除向量
- vectorStore.delete(List.of(oldKnowledgeSegment.getVectorId()));
- knowledgeSegment.setVectorId("");
+ deleteVectorStore(vectorStore, segment);
}
- // 3 更新段落状态
- segmentMapper.updateById(knowledgeSegment);
}
@Override
- public List similaritySearch(AiKnowledgeSegmentSearchReqVO reqVO) {
+ public void reindexKnowledgeSegmentByKnowledgeId(Long knowledgeId) {
+ // 1.1 校验知识库存在
+ AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
+ // 1.2 获取知识库向量实例
+ VectorStore vectorStore = getVectorStoreById(knowledge);
+
+ // 2.1 查询知识库下的所有启用状态的段落
+ List segments = segmentMapper.selectListByKnowledgeIdAndStatus(
+ knowledgeId, CommonStatusEnum.ENABLE.getStatus());
+ if (CollUtil.isEmpty(segments)) {
+ return;
+ }
+ // 2.2 遍历所有段落,重新索引
+ for (AiKnowledgeSegmentDO segment : segments) {
+ // 删除旧的向量
+ deleteVectorStore(vectorStore, segment);
+ // 重新创建向量
+ writeVectorStore(vectorStore, segment, new Document(segment.getContent()));
+ }
+ log.info("[reindexKnowledgeSegmentByKnowledgeId][知识库({}) 重新索引完成,共处理 {} 个段落]",
+ knowledgeId, segments.size());
+ }
+
+ private void writeVectorStore(VectorStore vectorStore, AiKnowledgeSegmentDO segmentDO, Document segment) {
+ // 1. 向量存储
+ // 为什么要 toString 呢?因为部分 VectorStore 实现,不支持 Long 类型,例如说 QdrantVectorStore
+ segment.getMetadata().put(VECTOR_STORE_METADATA_KNOWLEDGE_ID, segmentDO.getKnowledgeId().toString());
+ segment.getMetadata().put(VECTOR_STORE_METADATA_DOCUMENT_ID, segmentDO.getDocumentId().toString());
+ segment.getMetadata().put(VECTOR_STORE_METADATA_SEGMENT_ID, segmentDO.getId().toString());
+ vectorStore.add(List.of(segment));
+
+ // 2. 更新向量 ID
+ segmentMapper.updateById(new AiKnowledgeSegmentDO().setId(segmentDO.getId()).setVectorId(segment.getId()));
+ }
+
+ private void deleteVectorStore(VectorStore vectorStore, AiKnowledgeSegmentDO segmentDO) {
+ // 1. 更新向量 ID
+ if (StrUtil.isEmpty(segmentDO.getVectorId())) {
+ return;
+ }
+ segmentMapper.updateById(new AiKnowledgeSegmentDO().setId(segmentDO.getId())
+ .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY));
+
+ // 2. 删除向量
+ vectorStore.delete(List.of(segmentDO.getVectorId()));
+ }
+
+ @Override
+ public List searchKnowledgeSegment(AiKnowledgeSegmentSearchReqBO reqBO) {
// 1. 校验
- AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqVO.getKnowledgeId());
- AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
+ AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(reqBO.getKnowledgeId());
- // 2. 获取向量存储实例
- VectorStore vectorStore = apiKeyService.getOrCreateVectorStore(model.getKeyId());
-
- // 3.1 向量检索
- List documentList = vectorStore.similaritySearch(SearchRequest.query(reqVO.getContent())
- .withTopK(knowledge.getTopK())
- .withSimilarityThreshold(knowledge.getSimilarityThreshold())
- .withFilterExpression(new FilterExpressionBuilder().eq(AiKnowledgeSegmentDO.FIELD_KNOWLEDGE_ID, reqVO.getKnowledgeId()).build()));
- if (CollUtil.isEmpty(documentList)) {
+ // 2.1 向量检索
+ VectorStore vectorStore = getVectorStoreById(knowledge);
+ List documents = vectorStore.similaritySearch(SearchRequest.builder()
+ .query(reqBO.getContent())
+ .topK(ObjUtil.defaultIfNull(reqBO.getTopK(), knowledge.getTopK()))
+ .similarityThreshold(
+ ObjUtil.defaultIfNull(reqBO.getSimilarityThreshold(), knowledge.getSimilarityThreshold()))
+ .filterExpression(new FilterExpressionBuilder()
+ .eq(VECTOR_STORE_METADATA_KNOWLEDGE_ID, reqBO.getKnowledgeId().toString())
+ .build())
+ .build());
+ if (CollUtil.isEmpty(documents)) {
return ListUtil.empty();
}
- // 3.2 段落召回
- return segmentMapper.selectListByVectorIds(CollUtil.getFieldValues(documentList, "id", String.class));
+ // 2.2 段落召回
+ List segments = segmentMapper
+ .selectListByVectorIds(convertList(documents, Document::getId));
+ if (CollUtil.isEmpty(segments)) {
+ return ListUtil.empty();
+ }
+
+ // 3. 增加召回次数
+ segmentMapper.updateRetrievalCountIncrByIds(convertList(segments, AiKnowledgeSegmentDO::getId));
+
+ // 4. 构建结果
+ List result = convertList(segments, segment -> {
+ Document document = CollUtil.findOne(documents, // 找到对应的文档
+ doc -> Objects.equals(doc.getId(), segment.getVectorId()));
+ if (document == null) {
+ return null;
+ }
+ return BeanUtils.toBean(segment, AiKnowledgeSegmentSearchRespBO.class)
+ .setScore(document.getScore());
+ });
+ result.sort((o1, o2) -> Double.compare(o2.getScore(), o1.getScore())); // 按照分数降序排序
+ return result;
+ }
+
+ @Override
+ public List splitContent(String url, Integer segmentMaxTokens) {
+ // 1. 读取 URL 内容
+ String content = knowledgeDocumentService.readUrl(url);
+
+ // 2. 文档切片
+ List documentSegments = splitContentByToken(content, segmentMaxTokens);
+
+ // 3. 转换为段落对象
+ return convertList(documentSegments, segment -> {
+ if (StrUtil.isEmpty(segment.getText())) {
+ return null;
+ }
+ return new AiKnowledgeSegmentDO()
+ .setContent(segment.getText())
+ .setContentLength(segment.getText().length())
+ .setTokens(tokenCountEstimator.estimate(segment.getText()));
+ });
}
/**
@@ -131,4 +283,75 @@ public class AiKnowledgeSegmentServiceImpl implements AiKnowledgeSegmentService
return knowledgeSegment;
}
+ private VectorStore getVectorStoreById(AiKnowledgeDO knowledge) {
+ return modelService.getOrCreateVectorStore(knowledge.getEmbeddingModelId(), VECTOR_STORE_METADATA_TYPES);
+ }
+
+ private VectorStore getVectorStoreById(Long knowledgeId) {
+ AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(knowledgeId);
+ return getVectorStoreById(knowledge);
+ }
+
+ private static List splitContentByToken(String content, Integer segmentMaxTokens) {
+ TextSplitter textSplitter = buildTokenTextSplitter(segmentMaxTokens);
+ return textSplitter.apply(Collections.singletonList(new Document(content)));
+ }
+
+ private static TextSplitter buildTokenTextSplitter(Integer segmentMaxTokens) {
+ return TokenTextSplitter.builder()
+ .withChunkSize(segmentMaxTokens)
+ .withMinChunkSizeChars(Integer.MAX_VALUE) // 忽略字符的截断
+ .withMinChunkLengthToEmbed(1) // 允许的最小有效分段长度
+ .withMaxNumChunks(Integer.MAX_VALUE)
+ .withKeepSeparator(true) // 保留分隔符
+ .build();
+ }
+
+ @Override
+ public List getKnowledgeSegmentProcessList(List documentIds) {
+ if (CollUtil.isEmpty(documentIds)) {
+ return Collections.emptyList();
+ }
+ return segmentMapper.selectProcessList(documentIds);
+ }
+
+ @Override
+ public Long createKnowledgeSegment(AiKnowledgeSegmentSaveReqVO createReqVO) {
+ // 1.1 校验文档是否存在
+ AiKnowledgeDocumentDO document = knowledgeDocumentService
+ .validateKnowledgeDocumentExists(createReqVO.getDocumentId());
+ // 1.2 获取知识库信息
+ AiKnowledgeDO knowledge = knowledgeService.validateKnowledgeExists(document.getKnowledgeId());
+ // 1.3 校验 token 熟练
+ Integer tokens = tokenCountEstimator.estimate(createReqVO.getContent());
+ if (tokens > document.getSegmentMaxTokens()) {
+ throw exception(KNOWLEDGE_SEGMENT_CONTENT_TOO_LONG, tokens, document.getSegmentMaxTokens());
+ }
+
+ // 2. 保存段落
+ AiKnowledgeSegmentDO segment = BeanUtils.toBean(createReqVO, AiKnowledgeSegmentDO.class)
+ .setKnowledgeId(knowledge.getId()).setDocumentId(document.getId())
+ .setContentLength(createReqVO.getContent().length()).setTokens(tokens)
+ .setVectorId(AiKnowledgeSegmentDO.VECTOR_ID_EMPTY)
+ .setRetrievalCount(0).setStatus(CommonStatusEnum.ENABLE.getStatus());
+ segmentMapper.insert(segment);
+
+ // 3. 向量化
+ writeVectorStore(getVectorStoreById(knowledge), segment, new Document(segment.getContent()));
+ return segment.getId();
+ }
+
+ @Override
+ public AiKnowledgeSegmentDO getKnowledgeSegment(Long id) {
+ return segmentMapper.selectById(id);
+ }
+
+ @Override
+ public List getKnowledgeSegmentList(Collection ids) {
+ if (CollUtil.isEmpty(ids)) {
+ return Collections.emptyList();
+ }
+ return segmentMapper.selectBatchIds(ids);
+ }
+
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
index 7060076a42..5336570d27 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java
@@ -1,11 +1,11 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import org.springframework.ai.vectorstore.VectorStore;
+
+import java.util.List;
/**
* AI 知识库-基础信息 Service 接口
@@ -18,18 +18,24 @@ public interface AiKnowledgeService {
* 创建知识库
*
* @param createReqVO 创建信息
- * @param userId 用户编号
* @return 编号
*/
- Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId);
+ Long createKnowledge(AiKnowledgeSaveReqVO createReqVO);
/**
* 更新知识库
*
* @param updateReqVO 更新信息
- * @param userId 用户编号
*/
- void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId);
+ void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
+
+ /**
+ * 获得知识库
+ *
+ * @param id 编号
+ * @return 知识库
+ */
+ AiKnowledgeDO getKnowledge(Long id);
/**
* 校验知识库是否存在
@@ -41,18 +47,17 @@ public interface AiKnowledgeService {
/**
* 获得知识库分页
*
- * @param userId 用户编号
* @param pageReqVO 分页查询
* @return 知识库分页
*/
- PageResult getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO);
+ PageResult getKnowledgePage(AiKnowledgePageReqVO pageReqVO);
/**
- * 根据知识库编号获取向量存储实例
+ * 获得指定状态的知识库列表
*
- * @param id 知识库编号
- * @return 向量存储实例
+ * @param status 状态
+ * @return 知识库列表
*/
- VectorStore getVectorStoreById(Long id);
+ List getKnowledgeSimpleListByStatus(Integer status);
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
index 1a000c19d1..59afd7d7bc 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java
@@ -4,19 +4,19 @@ import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeCreateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeUpdateReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
-import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
+import java.util.List;
+
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.KNOWLEDGE_NOT_EXISTS;
@@ -33,36 +33,43 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiKnowledgeMapper knowledgeMapper;
@Resource
- private AiChatModelService chatModelService;
+ private AiModelService modelService;
@Resource
- private AiApiKeyService apiKeyService;
+ private AiKnowledgeSegmentService knowledgeSegmentService;
@Override
- public Long createKnowledge(AiKnowledgeCreateReqVO createReqVO, Long userId) {
+ public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
// 1. 校验模型配置
- AiChatModelDO model = chatModelService.validateChatModel(createReqVO.getModelId());
+ AiModelDO model = modelService.validateModel(createReqVO.getEmbeddingModelId());
// 2. 插入知识库
- AiKnowledgeDO knowledgeBase = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
- .setModel(model.getModel()).setUserId(userId).setStatus(CommonStatusEnum.ENABLE.getStatus());
- knowledgeMapper.insert(knowledgeBase);
- return knowledgeBase.getId();
+ AiKnowledgeDO knowledge = BeanUtils.toBean(createReqVO, AiKnowledgeDO.class)
+ .setEmbeddingModel(model.getModel());
+ knowledgeMapper.insert(knowledge);
+ return knowledge.getId();
}
@Override
- public void updateKnowledge(AiKnowledgeUpdateReqVO updateReqVO, Long userId) {
+ public void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO) {
// 1.1 校验知识库存在
- AiKnowledgeDO knowledgeBaseDO = validateKnowledgeExists(updateReqVO.getId());
- if (ObjUtil.notEqual(knowledgeBaseDO.getUserId(), userId)) {
- throw exception(KNOWLEDGE_NOT_EXISTS);
- }
+ AiKnowledgeDO oldKnowledge = validateKnowledgeExists(updateReqVO.getId());
// 1.2 校验模型配置
- AiChatModelDO model = chatModelService.validateChatModel(updateReqVO.getModelId());
+ AiModelDO model = modelService.validateModel(updateReqVO.getEmbeddingModelId());
// 2. 更新知识库
- AiKnowledgeDO updateDO = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class);
- updateDO.setModel(model.getModel());
- knowledgeMapper.updateById(updateDO);
+ AiKnowledgeDO updateObj = BeanUtils.toBean(updateReqVO, AiKnowledgeDO.class)
+ .setEmbeddingModel(model.getModel());
+ knowledgeMapper.updateById(updateObj);
+
+ // 3. 如果模型变化,需要 reindex 所有的文档
+ if (ObjUtil.notEqual(oldKnowledge.getEmbeddingModelId(), updateReqVO.getEmbeddingModelId())) {
+ knowledgeSegmentService.reindexByKnowledgeIdAsync(updateReqVO.getId());
+ }
+ }
+
+ @Override
+ public AiKnowledgeDO getKnowledge(Long id) {
+ return knowledgeMapper.selectById(id);
}
@Override
@@ -75,16 +82,13 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
}
@Override
- public PageResult getKnowledgePage(Long userId, AiKnowledgePageReqVO pageReqVO) {
- return knowledgeMapper.selectPage(userId, pageReqVO);
+ public PageResult getKnowledgePage(AiKnowledgePageReqVO pageReqVO) {
+ return knowledgeMapper.selectPage(pageReqVO);
}
@Override
- public VectorStore getVectorStoreById(Long id) {
- AiKnowledgeDO knowledge = validateKnowledgeExists(id);
- AiChatModelDO model = chatModelService.validateChatModel(knowledge.getModelId());
- // 创建或获取 VectorStore 对象
- return apiKeyService.getOrCreateVectorStore(model.getKeyId());
+ public List getKnowledgeSimpleListByStatus(Integer status) {
+ return knowledgeMapper.selectListByStatus(status);
}
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchReqBO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchReqBO.java
new file mode 100644
index 0000000000..9ff63b6460
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchReqBO.java
@@ -0,0 +1,39 @@
+package cn.iocoder.yudao.module.ai.service.knowledge.bo;
+
+import lombok.Data;
+
+import javax.validation.constraints.NotNull;
+
+import jakarta.validation.constraints.NotEmpty;
+
+/**
+ * AI 知识库段落搜索 Request BO
+ *
+ * @author 芋道源码
+ */
+@Data
+public class AiKnowledgeSegmentSearchReqBO {
+
+ /**
+ * 知识库编号
+ */
+ @NotNull(message = "知识库编号不能为空")
+ private Long knowledgeId;
+
+ /**
+ * 内容
+ */
+ @NotEmpty(message = "内容不能为空")
+ private String content;
+
+ /**
+ * 最大返回数量
+ */
+ private Integer topK;
+
+ /**
+ * 相似度阈值
+ */
+ private Double similarityThreshold;
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchRespBO.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchRespBO.java
new file mode 100644
index 0000000000..72eb84624a
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/bo/AiKnowledgeSegmentSearchRespBO.java
@@ -0,0 +1,45 @@
+package cn.iocoder.yudao.module.ai.service.knowledge.bo;
+
+import lombok.Data;
+
+/**
+ * AI 知识库段落搜索 Response BO
+ *
+ * @author 芋道源码
+ */
+@Data
+public class AiKnowledgeSegmentSearchRespBO {
+
+ /**
+ * 段落编号
+ */
+ private Long id;
+ /**
+ * 文档编号
+ */
+ private Long documentId;
+ /**
+ * 知识库编号
+ */
+ private Long knowledgeId;
+
+ /**
+ * 内容
+ */
+ private String content;
+ /**
+ * 内容长度
+ */
+ private Integer contentLength;
+
+ /**
+ * Token 数量
+ */
+ private Integer tokens;
+
+ /**
+ * 相似度分数
+ */
+ private Double score;
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
index b34bd63481..0dc851c216 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/mindmap/AiMindMapServiceImpl.java
@@ -1,8 +1,9 @@
package cn.iocoder.yudao.module.ai.service.mindmap;
import cn.hutool.core.collection.CollUtil;
-import cn.hutool.core.lang.Assert;
+import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
+import cn.iocoder.yudao.framework.ai.core.enums.AiModelTypeEnum;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
@@ -12,14 +13,13 @@ import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapGenerateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.mindmap.vo.AiMindMapPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.mindmap.AiMindMapDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.mindmap.AiMindMapMapper;
import cn.iocoder.yudao.module.ai.enums.AiChatRoleEnum;
import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
-import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.Message;
@@ -38,7 +38,7 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_EXISTS;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
/**
* AI 思维导图 Service 实现类
@@ -50,9 +50,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MIND_MAP_NOT_E
public class AiMindMapServiceImpl implements AiMindMapService {
@Resource
- private AiApiKeyService apiKeyService;
- @Resource
- private AiChatModelService chatModalService;
+ private AiModelService modalService;
@Resource
private AiChatRoleService chatRoleService;
@@ -65,17 +63,17 @@ public class AiMindMapServiceImpl implements AiMindMapService {
AiChatRoleDO role = CollUtil.getFirst(
chatRoleService.getChatRoleListByName(AiChatRoleEnum.AI_MIND_MAP_ROLE.getName()));
// 1.1 获取导图执行模型
- AiChatModelDO model = getModel(role);
+ AiModelDO model = getModel(role);
// 1.2 获取角色设定消息
String systemMessage = role != null && StrUtil.isNotBlank(role.getSystemMessage())
? role.getSystemMessage() : AiChatRoleEnum.AI_MIND_MAP_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
- ChatModel chatModel = apiKeyService.getChatModel(model.getKeyId());
+ ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 插入思维导图信息
- AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class,
- mindMap -> mindMap.setUserId(userId).setModel(model.getModel()).setPlatform(platform.getPlatform()));
+ AiMindMapDO mindMapDO = BeanUtils.toBean(generateReqVO, AiMindMapDO.class, mindMap -> mindMap.setUserId(userId)
+ .setPlatform(platform.getPlatform()).setModelId(model.getId()).setModel(model.getModel()));
mindMapMapper.insert(mindMapDO);
// 3.1 构建 Prompt,并进行调用
@@ -85,7 +83,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
// 3.2 流式返回
StringBuffer contentBuffer = new StringBuffer();
return streamResponse.map(chunk -> {
- String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getContent() : null;
+ String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
newContent = StrUtil.nullToDefault(newContent, ""); // 避免 null 的 情况
contentBuffer.append(newContent);
// 响应结果
@@ -103,7 +101,7 @@ public class AiMindMapServiceImpl implements AiMindMapService {
}
- private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiChatModelDO model, String systemMessage) {
+ private Prompt buildPrompt(AiMindMapGenerateReqVO generateReqVO, AiModelDO model, String systemMessage) {
// 1. 构建 message 列表
List chatMessages = buildMessages(generateReqVO, systemMessage);
// 2. 构建 options 对象
@@ -123,15 +121,21 @@ public class AiMindMapServiceImpl implements AiMindMapService {
return chatMessages;
}
- private AiChatModelDO getModel(AiChatRoleDO role) {
- AiChatModelDO model = null;
+ private AiModelDO getModel(AiChatRoleDO role) {
+ AiModelDO model = null;
if (role != null && role.getModelId() != null) {
- model = chatModalService.getChatModel(role.getModelId());
+ model = modalService.getModel(role.getModelId());
}
if (model == null) {
- model = chatModalService.getRequiredDefaultChatModel();
+ model = modalService.getRequiredDefaultModel(AiModelTypeEnum.CHAT.getType());
+ }
+ // 校验模型存在、且合法
+ if (model == null) {
+ throw exception(MODEL_NOT_EXISTS);
+ }
+ if (ObjUtil.notEqual(model.getType(), AiModelTypeEnum.CHAT.getType())) {
+ throw exception(MODEL_USE_TYPE_ERROR);
}
- Assert.notNull(model, "[AI] 获取不到模型");
return model;
}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java
index f5f8813492..44da80041c 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyService.java
@@ -1,17 +1,10 @@
package cn.iocoder.yudao.module.ai.service.model;
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
-import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.image.ImageModel;
-import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;
@@ -75,58 +68,13 @@ public interface AiApiKeyService {
*/
List getApiKeyList();
- // ========== 与 spring-ai 集成 ==========
-
/**
- * 获得 ChatModel 对象
- *
- * @param id 编号
- * @return ChatModel 对象
- */
- ChatModel getChatModel(Long id);
-
- /**
- * 获得 ImageModel 对象
- *
- * TODO 可优化点:目前默认获取 platform 对应的第一个开启的配置用于绘画;后续可以支持配置选择
+ * 获得默认的 API 密钥
*
* @param platform 平台
- * @return ImageModel 对象
+ * @param status 状态
+ * @return API 密钥
*/
- ImageModel getImageModel(AiPlatformEnum platform);
-
- /**
- * 获得 MidjourneyApi 对象
- *
- * TODO 可优化点:目前默认获取 Midjourney 对应的第一个开启的配置用于绘画;后续可以支持配置选择
- *
- * @return MidjourneyApi 对象
- */
- MidjourneyApi getMidjourneyApi();
-
- /**
- * 获得 SunoApi 对象
- *
- * TODO 可优化点:目前默认获取 Suno 对应的第一个开启的配置用于音乐;后续可以支持配置选择
- *
- * @return SunoApi 对象
- */
- SunoApi getSunoApi();
-
- /**
- * 获得 EmbeddingModel 对象
- *
- * @param id 编号
- * @return EmbeddingModel 对象
- */
- EmbeddingModel getEmbeddingModel(Long id);
-
- /**
- * 获得 VectorStore 对象
- *
- * @param id 编号
- * @return VectorStore 对象
- */
- VectorStore getOrCreateVectorStore(Long id);
+ AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status);
}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
index 50e1fbd7ac..f0bac8a6d9 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiApiKeyServiceImpl.java
@@ -1,9 +1,5 @@
package cn.iocoder.yudao.module.ai.service.model;
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
-import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
-import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@@ -12,17 +8,14 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.image.ImageModel;
-import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_DISABLE;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.API_KEY_NOT_EXISTS;
/**
* AI API 密钥 Service 实现类
@@ -36,9 +29,6 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
@Resource
private AiApiKeyMapper apiKeyMapper;
- @Resource
- private AiModelFactory modelFactory;
-
@Override
public Long createApiKey(AiApiKeySaveReqVO createReqVO) {
// 插入
@@ -97,57 +87,13 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return apiKeyMapper.selectList();
}
- // ========== 与 spring-ai 集成 ==========
-
@Override
- public ChatModel getChatModel(Long id) {
- AiApiKeyDO apiKey = validateApiKey(id);
- AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
- return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
- }
-
- @Override
- public ImageModel getImageModel(AiPlatformEnum platform) {
- AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+ public AiApiKeyDO getRequiredDefaultApiKey(String platform, Integer status) {
+ AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform, status);
if (apiKey == null) {
- throw exception(API_KEY_IMAGE_NODE_FOUND, platform.getName());
+ throw exception(API_KEY_NOT_EXISTS);
}
- return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
- }
-
- @Override
- public MidjourneyApi getMidjourneyApi() {
- AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
- AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
- if (apiKey == null) {
- throw exception(API_KEY_MIDJOURNEY_NOT_FOUND);
- }
- return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
- }
-
- @Override
- public SunoApi getSunoApi() {
- AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
- AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
- if (apiKey == null) {
- throw exception(API_KEY_SUNO_NOT_FOUND);
- }
- return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
- }
-
- @Override
- public EmbeddingModel getEmbeddingModel(Long id) {
- AiApiKeyDO apiKey = validateApiKey(id);
- AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
- return modelFactory.getOrCreateEmbeddingModel(platform, apiKey.getApiKey(), apiKey.getUrl());
- }
-
- @Override
- public VectorStore getOrCreateVectorStore(Long id) {
- AiApiKeyDO apiKey = validateApiKey(id);
- AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
- // 创建或获取 VectorStore 对象
- return modelFactory.getOrCreateVectorStore(getEmbeddingModel(id), platform, apiKey.getApiKey(), apiKey.getUrl());
+ return apiKey;
}
}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java
deleted file mode 100644
index f83ac73c91..0000000000
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelService.java
+++ /dev/null
@@ -1,92 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.model;
-
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import jakarta.validation.Valid;
-
-import java.util.Collection;
-import java.util.List;
-
-import java.util.Set;
-
-/**
- * AI 聊天模型 Service 接口
- *
- * @author fansili
- * @since 2024/4/24 19:42
- */
-public interface AiChatModelService {
-
- /**
- * 创建聊天模型
- *
- * @param createReqVO 创建信息
- * @return 编号
- */
- Long createChatModel(@Valid AiChatModelSaveReqVO createReqVO);
-
- /**
- * 更新聊天模型
- *
- * @param updateReqVO 更新信息
- */
- void updateChatModel(@Valid AiChatModelSaveReqVO updateReqVO);
-
- /**
- * 删除聊天模型
- *
- * @param id 编号
- */
- void deleteChatModel(Long id);
-
- /**
- * 获得聊天模型
- *
- * @param id 编号
- * @return 聊天模型
- */
- AiChatModelDO getChatModel(Long id);
-
- /**
- * 获得默认的聊天模型
- *
- * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
- *
- * @return 聊天模型
- */
- AiChatModelDO getRequiredDefaultChatModel();
-
- /**
- * 获得聊天模型分页
- *
- * @param pageReqVO 分页查询
- * @return 聊天模型分页
- */
- PageResult getChatModelPage(AiChatModelPageReqVO pageReqVO);
-
- /**
- * 校验聊天模型
- *
- * @param id 编号
- * @return 聊天模型
- */
- AiChatModelDO validateChatModel(Long id);
-
- /**
- * 获得聊天模型列表
- *
- * @param status 状态
- * @return 聊天模型列表
- */
- List getChatModelListByStatus(Integer status);
-
- /**
- * 获得聊天模型列表
- *
- * @param ids 编号数组
- * @return 模型列表
- */
- List getChatModelList(Collection ids);
-}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java
deleted file mode 100644
index 4b11602f54..0000000000
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatModelServiceImpl.java
+++ /dev/null
@@ -1,114 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.model;
-
-import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
-import cn.iocoder.yudao.framework.common.pojo.PageResult;
-import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelPageReqVO;
-import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatModel.AiChatModelSaveReqVO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatModelDO;
-import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatModelMapper;
-import jakarta.annotation.Resource;
-import org.springframework.stereotype.Service;
-import org.springframework.validation.annotation.Validated;
-
-import java.util.Collection;
-import java.util.List;
-
-import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
-
-/**
- * AI 聊天模型 Service 实现类
- *
- * @author fansili
- */
-@Service
-@Validated
-public class AiChatModelServiceImpl implements AiChatModelService {
-
- @Resource
- private AiApiKeyService apiKeyService;
-
- @Resource
- private AiChatModelMapper chatModelMapper;
-
- @Override
- public Long createChatModel(AiChatModelSaveReqVO createReqVO) {
- // 1. 校验
- AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
- apiKeyService.validateApiKey(createReqVO.getKeyId());
-
- // 2. 插入
- AiChatModelDO chatModel = BeanUtils.toBean(createReqVO, AiChatModelDO.class);
- chatModelMapper.insert(chatModel);
- return chatModel.getId();
- }
-
- @Override
- public void updateChatModel(AiChatModelSaveReqVO updateReqVO) {
- // 1. 校验
- validateChatModelExists(updateReqVO.getId());
- AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
- apiKeyService.validateApiKey(updateReqVO.getKeyId());
-
- // 2. 更新
- AiChatModelDO updateObj = BeanUtils.toBean(updateReqVO, AiChatModelDO.class);
- chatModelMapper.updateById(updateObj);
- }
-
- @Override
- public void deleteChatModel(Long id) {
- // 校验存在
- validateChatModelExists(id);
- // 删除
- chatModelMapper.deleteById(id);
- }
-
- private AiChatModelDO validateChatModelExists(Long id) {
- AiChatModelDO model = chatModelMapper.selectById(id);
- if (chatModelMapper.selectById(id) == null) {
- throw exception(CHAT_MODEL_NOT_EXISTS);
- }
- return model;
- }
-
- @Override
- public AiChatModelDO getChatModel(Long id) {
- return chatModelMapper.selectById(id);
- }
-
- @Override
- public AiChatModelDO getRequiredDefaultChatModel() {
- AiChatModelDO model = chatModelMapper.selectFirstByStatus(CommonStatusEnum.ENABLE.getStatus());
- if (model == null) {
- throw exception(CHAT_MODEL_DEFAULT_NOT_EXISTS);
- }
- return model;
- }
-
- @Override
- public PageResult getChatModelPage(AiChatModelPageReqVO pageReqVO) {
- return chatModelMapper.selectPage(pageReqVO);
- }
-
- @Override
- public AiChatModelDO validateChatModel(Long id) {
- AiChatModelDO model = validateChatModelExists(id);
- if (CommonStatusEnum.isDisable(model.getStatus())) {
- throw exception(CHAT_MODEL_DISABLE);
- }
- return model;
- }
-
- @Override
- public List getChatModelListByStatus(Integer status) {
- return chatModelMapper.selectList(status);
- }
-
- @Override
- public List getChatModelList(Collection ids) {
- return chatModelMapper.selectBatchIds(ids);
- }
-
-}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java
index 81c8d259b6..e7fecf6ac7 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleService.java
@@ -32,7 +32,7 @@ public interface AiChatRoleService {
* 创建【我的】聊天角色
*
* @param createReqVO 创建信息
- * @param userId 用户编号
+ * @param userId 用户编号
* @return 编号
*/
Long createChatRoleMy(AiChatRoleSaveMyReqVO createReqVO, Long userId);
@@ -48,7 +48,7 @@ public interface AiChatRoleService {
* 创建【我的】聊天角色
*
* @param updateReqVO 更新信息
- * @param userId 用户编号
+ * @param userId 用户编号
*/
void updateChatRoleMy(AiChatRoleSaveMyReqVO updateReqVO, Long userId);
@@ -62,7 +62,7 @@ public interface AiChatRoleService {
/**
* 删除【我的】聊天角色
*
- * @param id 编号
+ * @param id 编号
* @param userId 用户编号
*/
void deleteChatRoleMy(Long id, Long userId);
@@ -106,7 +106,7 @@ public interface AiChatRoleService {
* 获得【我的】聊天角色分页
*
* @param pageReqVO 分页查询
- * @param userId 用户编号
+ * @param userId 用户编号
* @return 聊天角色分页
*/
PageResult getChatRoleMyPage(AiChatRolePageReqVO pageReqVO, Long userId);
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java
index 2cf4d46d1c..b0005c3af5 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiChatRoleServiceImpl.java
@@ -11,6 +11,7 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleS
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.chatRole.AiChatRoleSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiChatRoleDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatRoleMapper;
+import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -21,7 +22,8 @@ import java.util.List;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
-import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_ROLE_DISABLE;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_ROLE_NOT_EXISTS;
/**
* AI 聊天角色 Service 实现类
@@ -35,8 +37,19 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Resource
private AiChatRoleMapper chatRoleMapper;
+ @Resource
+ private AiKnowledgeService knowledgeService;
+ @Resource
+ private AiToolService toolService;
+
@Override
public Long createChatRole(AiChatRoleSaveReqVO createReqVO) {
+ // 校验文档
+ validateDocuments(createReqVO.getKnowledgeIds());
+ // 校验工具
+ validateTools(createReqVO.getToolIds());
+
+ // 保存角色
AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class);
chatRoleMapper.insert(chatRole);
return chatRole.getId();
@@ -44,6 +57,12 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override
public Long createChatRoleMy(AiChatRoleSaveMyReqVO createReqVO, Long userId) {
+ // 校验文档
+ validateDocuments(createReqVO.getKnowledgeIds());
+ // 校验工具
+ validateTools(createReqVO.getToolIds());
+
+ // 保存角色
AiChatRoleDO chatRole = BeanUtils.toBean(createReqVO, AiChatRoleDO.class).setUserId(userId)
.setStatus(CommonStatusEnum.ENABLE.getStatus()).setPublicStatus(false);
chatRoleMapper.insert(chatRole);
@@ -54,7 +73,12 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
public void updateChatRole(AiChatRoleSaveReqVO updateReqVO) {
// 校验存在
validateChatRoleExists(updateReqVO.getId());
- // 更新
+ // 校验文档
+ validateDocuments(updateReqVO.getKnowledgeIds());
+ // 校验工具
+ validateTools(updateReqVO.getToolIds());
+
+ // 更新角色
AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
chatRoleMapper.updateById(updateObj);
}
@@ -66,12 +90,42 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
if (ObjectUtil.notEqual(chatRole.getUserId(), userId)) {
throw exception(CHAT_ROLE_NOT_EXISTS);
}
+ // 校验文档
+ validateDocuments(updateReqVO.getKnowledgeIds());
+ // 校验工具
+ validateTools(updateReqVO.getToolIds());
// 更新
AiChatRoleDO updateObj = BeanUtils.toBean(updateReqVO, AiChatRoleDO.class);
chatRoleMapper.updateById(updateObj);
}
+ /**
+ * 校验知识库是否存在
+ *
+ * @param knowledgeIds 知识库编号列表
+ */
+ private void validateDocuments(List knowledgeIds) {
+ if (CollUtil.isEmpty(knowledgeIds)) {
+ return;
+ }
+ // 校验文档是否存在
+ knowledgeIds.forEach(knowledgeService::validateKnowledgeExists);
+ }
+
+ /**
+ * 校验工具是否存在
+ *
+ * @param toolIds 工具编号列表
+ */
+ private void validateTools(List toolIds) {
+ if (CollUtil.isEmpty(toolIds)) {
+ return;
+ }
+ // 遍历校验每个工具是否存在
+ toolIds.forEach(toolService::validateToolExists);
+ }
+
@Override
public void deleteChatRole(Long id) {
// 校验存在
@@ -134,7 +188,8 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
@Override
public List getChatRoleCategoryList() {
List list = chatRoleMapper.selectListGroupByCategory(CommonStatusEnum.ENABLE.getStatus());
- return convertList(list, AiChatRoleDO::getCategory, role -> role != null && StrUtil.isNotBlank(role.getCategory()));
+ return convertList(list, AiChatRoleDO::getCategory,
+ role -> role != null && StrUtil.isNotBlank(role.getCategory()));
}
@Override
@@ -143,4 +198,3 @@ public class AiChatRoleServiceImpl implements AiChatRoleService {
}
}
-
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java
new file mode 100644
index 0000000000..127f72cc46
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelService.java
@@ -0,0 +1,134 @@
+package cn.iocoder.yudao.module.ai.service.model;
+
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
+import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import jakarta.validation.Valid;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.VectorStore;
+
+import javax.annotation.Nullable;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * AI 模型 Service 接口
+ *
+ * @author fansili
+ * @since 2024/4/24 19:42
+ */
+public interface AiModelService {
+
+ /**
+ * 创建模型
+ *
+ * @param createReqVO 创建信息
+ * @return 编号
+ */
+ Long createModel(@Valid AiModelSaveReqVO createReqVO);
+
+ /**
+ * 更新模型
+ *
+ * @param updateReqVO 更新信息
+ */
+ void updateModel(@Valid AiModelSaveReqVO updateReqVO);
+
+ /**
+ * 删除模型
+ *
+ * @param id 编号
+ */
+ void deleteModel(Long id);
+
+ /**
+ * 获得模型
+ *
+ * @param id 编号
+ * @return 模型
+ */
+ AiModelDO getModel(Long id);
+
+ /**
+ * 获得默认的模型
+ *
+ * 如果获取不到,则抛出 {@link cn.iocoder.yudao.framework.common.exception.ServiceException} 业务异常
+ *
+ * @return 模型
+ */
+ AiModelDO getRequiredDefaultModel(Integer type);
+
+ /**
+ * 获得模型分页
+ *
+ * @param pageReqVO 分页查询
+ * @return 模型分页
+ */
+ PageResult getModelPage(AiModelPageReqVO pageReqVO);
+
+ /**
+ * 校验模型是否可使用
+ *
+ * @param id 编号
+ * @return 模型
+ */
+ AiModelDO validateModel(Long id);
+
+ /**
+ * 获得模型列表
+ *
+ * @param status 状态
+ * @param type 类型
+ * @param platform 平台,允许空
+ * @return 模型列表
+ */
+ List getModelListByStatusAndType(Integer status, Integer type,
+ @Nullable String platform);
+
+ // ========== 与 Spring AI 集成 ==========
+
+ /**
+ * 获得 ChatModel 对象
+ *
+ * @param id 编号
+ * @return ChatModel 对象
+ */
+ ChatModel getChatModel(Long id);
+
+ /**
+ * 获得 ImageModel 对象
+ *
+ * @param id 编号
+ * @return ImageModel 对象
+ */
+ ImageModel getImageModel(Long id);
+
+ /**
+ * 获得 MidjourneyApi 对象
+ *
+ * @param id 编号
+ * @return MidjourneyApi 对象
+ */
+ MidjourneyApi getMidjourneyApi(Long id);
+
+ /**
+ * 获得 SunoApi 对象
+ *
+ * @return SunoApi 对象
+ */
+ SunoApi getSunoApi();
+
+ /**
+ * 获得 VectorStore 对象
+ *
+ * @param id 编号
+ * @param metadataFields 元数据的定义
+ * @return VectorStore 对象
+ */
+ VectorStore getOrCreateVectorStore(Long id, Map> metadataFields);
+
+}
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelServiceImpl.java
new file mode 100644
index 0000000000..b0e9e97172
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiModelServiceImpl.java
@@ -0,0 +1,171 @@
+package cn.iocoder.yudao.module.ai.service.model;
+
+import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.factory.AiModelFactory;
+import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
+import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
+import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
+import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
+import jakarta.annotation.Resource;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.embedding.EmbeddingModel;
+import org.springframework.ai.image.ImageModel;
+import org.springframework.ai.vectorstore.SimpleVectorStore;
+import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.stereotype.Service;
+import org.springframework.validation.annotation.Validated;
+
+import java.util.List;
+import java.util.Map;
+
+import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.*;
+
+/**
+ * AI 模型 Service 实现类
+ *
+ * @author fansili
+ */
+@Service
+@Validated
+public class AiModelServiceImpl implements AiModelService {
+
+ @Resource
+ private AiApiKeyService apiKeyService;
+
+ @Resource
+ private AiChatMapper modelMapper;
+
+ @Resource
+ private AiModelFactory modelFactory;
+
+ @Override
+ public Long createModel(AiModelSaveReqVO createReqVO) {
+ // 1. 校验
+ AiPlatformEnum.validatePlatform(createReqVO.getPlatform());
+ apiKeyService.validateApiKey(createReqVO.getKeyId());
+
+ // 2. 插入
+ AiModelDO model = BeanUtils.toBean(createReqVO, AiModelDO.class);
+ modelMapper.insert(model);
+ return model.getId();
+ }
+
+ @Override
+ public void updateModel(AiModelSaveReqVO updateReqVO) {
+ // 1. 校验
+ validateModelExists(updateReqVO.getId());
+ AiPlatformEnum.validatePlatform(updateReqVO.getPlatform());
+ apiKeyService.validateApiKey(updateReqVO.getKeyId());
+
+ // 2. 更新
+ AiModelDO updateObj = BeanUtils.toBean(updateReqVO, AiModelDO.class);
+ modelMapper.updateById(updateObj);
+ }
+
+ @Override
+ public void deleteModel(Long id) {
+ // 校验存在
+ validateModelExists(id);
+ // 删除
+ modelMapper.deleteById(id);
+ }
+
+ private AiModelDO validateModelExists(Long id) {
+ AiModelDO model = modelMapper.selectById(id);
+ if (modelMapper.selectById(id) == null) {
+ throw exception(MODEL_NOT_EXISTS);
+ }
+ return model;
+ }
+
+ @Override
+ public AiModelDO getModel(Long id) {
+ return modelMapper.selectById(id);
+ }
+
+ @Override
+ public AiModelDO getRequiredDefaultModel(Integer type) {
+ AiModelDO model = modelMapper.selectFirstByStatus(type, CommonStatusEnum.ENABLE.getStatus());
+ if (model == null) {
+ throw exception(MODEL_DEFAULT_NOT_EXISTS);
+ }
+ return model;
+ }
+
+ @Override
+ public PageResult getModelPage(AiModelPageReqVO pageReqVO) {
+ return modelMapper.selectPage(pageReqVO);
+ }
+
+ @Override
+ public AiModelDO validateModel(Long id) {
+ AiModelDO model = validateModelExists(id);
+ if (CommonStatusEnum.isDisable(model.getStatus())) {
+ throw exception(MODEL_DISABLE);
+ }
+ return model;
+ }
+
+ @Override
+ public List getModelListByStatusAndType(Integer status, Integer type, String platform) {
+ return modelMapper.selectListByStatusAndType(status, type, platform);
+ }
+
+ // ========== 与 Spring AI 集成 ==========
+
+ @Override
+ public ChatModel getChatModel(Long id) {
+ AiModelDO model = validateModel(id);
+ AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+ return modelFactory.getOrCreateChatModel(platform, apiKey.getApiKey(), apiKey.getUrl());
+ }
+
+ @Override
+ public ImageModel getImageModel(Long id) {
+ AiModelDO model = validateModel(id);
+ AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+ return modelFactory.getOrCreateImageModel(platform, apiKey.getApiKey(), apiKey.getUrl());
+ }
+
+ @Override
+ public MidjourneyApi getMidjourneyApi(Long id) {
+ AiModelDO model = validateModel(id);
+ AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+ return modelFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
+ }
+
+ @Override
+ public SunoApi getSunoApi() {
+ AiApiKeyDO apiKey = apiKeyService.getRequiredDefaultApiKey(
+ AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
+ return modelFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
+ }
+
+ @Override
+ public VectorStore getOrCreateVectorStore(Long id, Map> metadataFields) {
+ // 获取模型 + 密钥
+ AiModelDO model = validateModel(id);
+ AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
+ AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
+
+ // 创建或获取 EmbeddingModel 对象
+ EmbeddingModel embeddingModel = modelFactory.getOrCreateEmbeddingModel(
+ platform, apiKey.getApiKey(), apiKey.getUrl(), model.getModel());
+
+ // 创建或获取 VectorStore 对象
+ return modelFactory.getOrCreateVectorStore(SimpleVectorStore.class, embeddingModel, metadataFields);
+// return modelFactory.getOrCreateVectorStore(QdrantVectorStore.class, embeddingModel, metadataFields);
+// return modelFactory.getOrCreateVectorStore(RedisVectorStore.class, embeddingModel, metadataFields);
+// return modelFactory.getOrCreateVectorStore(MilvusVectorStore.class, embeddingModel, metadataFields);
+ }
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolService.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolService.java
new file mode 100644
index 0000000000..fb23224a83
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolService.java
@@ -0,0 +1,80 @@
+package cn.iocoder.yudao.module.ai.service.model;
+
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
+import jakarta.validation.Valid;
+
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * AI 工具 Service 接口
+ *
+ * @author 芋道源码
+ */
+public interface AiToolService {
+
+ /**
+ * 创建工具
+ *
+ * @param createReqVO 创建信息
+ * @return 编号
+ */
+ Long createTool(@Valid AiToolSaveReqVO createReqVO);
+
+ /**
+ * 更新工具
+ *
+ * @param updateReqVO 更新信息
+ */
+ void updateTool(@Valid AiToolSaveReqVO updateReqVO);
+
+ /**
+ * 删除工具
+ *
+ * @param id 编号
+ */
+ void deleteTool(Long id);
+
+ /**
+ * 校验工具是否存在
+ *
+ * @param id 编号
+ */
+ void validateToolExists(Long id);
+
+ /**
+ * 获得工具
+ *
+ * @param id 编号
+ * @return 工具
+ */
+ AiToolDO getTool(Long id);
+
+ /**
+ * 获得工具列表
+ *
+ * @param ids 编号列表
+ * @return 工具列表
+ */
+ List getToolList(Collection ids);
+
+ /**
+ * 获得工具分页
+ *
+ * @param pageReqVO 分页查询
+ * @return 工具分页
+ */
+ PageResult getToolPage(AiToolPageReqVO pageReqVO);
+
+ /**
+ * 获得工具列表
+ *
+ * @param status 状态
+ * @return 工具列表
+ */
+ List getToolListByStatus(Integer status);
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolServiceImpl.java
new file mode 100644
index 0000000000..59f8f74d1f
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/AiToolServiceImpl.java
@@ -0,0 +1,100 @@
+package cn.iocoder.yudao.module.ai.service.model;
+
+import cn.hutool.extra.spring.SpringUtil;
+import cn.iocoder.yudao.framework.common.pojo.PageResult;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolPageReqVO;
+import cn.iocoder.yudao.module.ai.controller.admin.model.vo.tool.AiToolSaveReqVO;
+import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiToolDO;
+import cn.iocoder.yudao.module.ai.dal.mysql.model.AiToolMapper;
+import jakarta.annotation.Resource;
+import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.stereotype.Service;
+import org.springframework.validation.annotation.Validated;
+
+import java.util.Collection;
+import java.util.List;
+
+import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NAME_NOT_EXISTS;
+import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.TOOL_NOT_EXISTS;
+
+/**
+ * AI 工具 Service 实现类
+ *
+ * @author 芋道源码
+ */
+@Service
+@Validated
+public class AiToolServiceImpl implements AiToolService {
+
+ @Resource
+ private AiToolMapper toolMapper;
+
+ @Override
+ public Long createTool(AiToolSaveReqVO createReqVO) {
+ // 校验名称是否存在
+ validateToolNameExists(createReqVO.getName());
+
+ // 插入
+ AiToolDO tool = BeanUtils.toBean(createReqVO, AiToolDO.class);
+ toolMapper.insert(tool);
+ return tool.getId();
+ }
+
+ @Override
+ public void updateTool(AiToolSaveReqVO updateReqVO) {
+ // 1.1 校验存在
+ validateToolExists(updateReqVO.getId());
+ // 1.2 校验名称是否存在
+ validateToolNameExists(updateReqVO.getName());
+
+ // 2. 更新
+ AiToolDO updateObj = BeanUtils.toBean(updateReqVO, AiToolDO.class);
+ toolMapper.updateById(updateObj);
+ }
+
+ @Override
+ public void deleteTool(Long id) {
+ // 校验存在
+ validateToolExists(id);
+ // 删除
+ toolMapper.deleteById(id);
+ }
+
+ @Override
+ public void validateToolExists(Long id) {
+ if (toolMapper.selectById(id) == null) {
+ throw exception(TOOL_NOT_EXISTS);
+ }
+ }
+
+ private void validateToolNameExists(String name) {
+ try {
+ SpringUtil.getBean(name);
+ } catch (NoSuchBeanDefinitionException e) {
+ throw exception(TOOL_NAME_NOT_EXISTS, name);
+ }
+ }
+
+ @Override
+ public AiToolDO getTool(Long id) {
+ return toolMapper.selectById(id);
+ }
+
+ @Override
+ public List getToolList(Collection ids) {
+ return toolMapper.selectBatchIds(ids);
+ }
+
+ @Override
+ public PageResult getToolPage(AiToolPageReqVO pageReqVO) {
+ return toolMapper.selectPage(pageReqVO);
+ }
+
+ @Override
+ public List getToolListByStatus(Integer status) {
+ return toolMapper.selectListByStatus(status);
+ }
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/DirectoryListToolFunction.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/DirectoryListToolFunction.java
new file mode 100644
index 0000000000..787b2e7728
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/DirectoryListToolFunction.java
@@ -0,0 +1,99 @@
+package cn.iocoder.yudao.module.ai.service.model.tool;
+
+import cn.hutool.core.date.LocalDateTimeUtil;
+import cn.hutool.core.io.FileUtil;
+import cn.hutool.core.util.ArrayUtil;
+import cn.hutool.core.util.StrUtil;
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+
+import static cn.hutool.core.date.DatePattern.NORM_DATETIME_PATTERN;
+import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
+
+/**
+ * 工具:列出指定目录的文件列表
+ *
+ * @author 芋道源码
+ */
+@Component("directory_list")
+public class DirectoryListToolFunction implements Function {
+
+ @Data
+ @JsonClassDescription("列出指定目录的文件列表")
+ public static class Request {
+
+ /**
+ * 目录路径
+ */
+ @JsonProperty(required = true, value = "path")
+ @JsonPropertyDescription("目录路径,例如说:/Users/yunai")
+ private String path;
+
+ }
+
+ @Data
+ @AllArgsConstructor
+ @NoArgsConstructor
+ public static class Response {
+
+ /**
+ * 文件列表
+ */
+ private List files;
+
+ @Data
+ public static class File {
+
+ /**
+ * 是否为目录
+ */
+ private Boolean directory;
+
+ /**
+ * 名称
+ */
+ private String name;
+
+ /**
+ * 大小,仅对文件有效
+ */
+ private String size;
+
+ /**
+ * 最后修改时间
+ */
+ private String lastModified;
+
+ }
+
+ }
+
+ @Override
+ public Response apply(Request request) {
+ // 校验目录存在
+ String path = StrUtil.blankToDefault(request.getPath(), "/");
+ if (!FileUtil.exist(path) || !FileUtil.isDirectory(path)) {
+ return new Response(Collections.emptyList());
+ }
+ // 列出目录
+ File[] files = FileUtil.ls(path);
+ if (ArrayUtil.isEmpty(files)) {
+ return new Response(Collections.emptyList());
+ }
+ return new Response(convertList(Arrays.asList(files), file ->
+ new Response.File().setDirectory(file.isDirectory()).setName(file.getName())
+ .setLastModified(LocalDateTimeUtil.format(LocalDateTimeUtil.of(file.lastModified()), NORM_DATETIME_PATTERN))));
+ }
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/WeatherQueryToolFunction.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/WeatherQueryToolFunction.java
new file mode 100644
index 0000000000..99262fafad
--- /dev/null
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/WeatherQueryToolFunction.java
@@ -0,0 +1,118 @@
+package cn.iocoder.yudao.module.ai.service.model.tool;
+
+import cn.hutool.core.date.LocalDateTimeUtil;
+import cn.hutool.core.util.RandomUtil;
+import cn.hutool.core.util.StrUtil;
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.stereotype.Component;
+
+import java.time.LocalDateTime;
+import java.util.function.Function;
+
+import static cn.hutool.core.date.DatePattern.NORM_DATETIME_PATTERN;
+
+/**
+ * 工具:查询指定城市的天气信息
+ *
+ * @author 芋道源码
+ */
+@Component("weather_query")
+public class WeatherQueryToolFunction
+ implements Function {
+
+ private static final String[] WEATHER_CONDITIONS = { "晴朗", "多云", "阴天", "小雨", "大雨", "雷雨", "小雪", "大雪" };
+
+ @Data
+ @JsonClassDescription("查询指定城市的天气信息")
+ public static class Request {
+
+ /**
+ * 城市名称
+ */
+ @JsonProperty(required = true, value = "city")
+ @JsonPropertyDescription("城市名称,例如:北京、上海、广州")
+ private String city;
+
+ }
+
+ @Data
+ @AllArgsConstructor
+ @NoArgsConstructor
+ public static class Response {
+
+ /**
+ * 城市名称
+ */
+ private String city;
+
+ /**
+ * 天气信息
+ */
+ private WeatherInfo weatherInfo;
+
+ @Data
+ @AllArgsConstructor
+ @NoArgsConstructor
+ public static class WeatherInfo {
+
+ /**
+ * 温度(摄氏度)
+ */
+ private Integer temperature;
+
+ /**
+ * 天气状况
+ */
+ private String condition;
+
+ /**
+ * 湿度百分比
+ */
+ private Integer humidity;
+
+ /**
+ * 风速(km/h)
+ */
+ private Integer windSpeed;
+
+ /**
+ * 查询时间
+ */
+ private String queryTime;
+
+ }
+
+ }
+
+ @Override
+ public Response apply(Request request) {
+ // 检查城市名称是否为空
+ if (StrUtil.isBlank(request.getCity())) {
+ return new Response("未知城市", null);
+ }
+
+ // 获取天气数据
+ String city = request.getCity();
+ Response.WeatherInfo weatherInfo = generateMockWeatherInfo();
+ return new Response(city, weatherInfo);
+ }
+
+ /**
+ * 生成模拟的天气数据
+ * 在实际应用中,应替换为真实 API 调用
+ */
+ private Response.WeatherInfo generateMockWeatherInfo() {
+ int temperature = RandomUtil.randomInt(-5, 30);
+ int humidity = RandomUtil.randomInt(1, 100);
+ int windSpeed = RandomUtil.randomInt(1, 30);
+ String condition = RandomUtil.randomEle(WEATHER_CONDITIONS);
+ return new Response.WeatherInfo(temperature, condition, humidity, windSpeed,
+ LocalDateTimeUtil.format(LocalDateTime.now(), NORM_DATETIME_PATTERN));
+ }
+
+}
\ No newline at end of file
diff --git a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java
index 3f10ec8402..e4ff81a477 100644
--- a/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java
+++ b/yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/music/AiMusicServiceImpl.java
@@ -16,7 +16,7 @@ import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
-import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
+import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
@@ -41,7 +41,7 @@ import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXIS
public class AiMusicServiceImpl implements AiMusicService {
@Resource
- private AiApiKeyService apiKeyService;
+ private AiModelService modelService;
@Resource
private AiMusicMapper musicMapper;
@@ -53,7 +53,7 @@ public class AiMusicServiceImpl implements AiMusicService {
@Transactional(rollbackFor = Exception.class)
public List generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐
- SunoApi sunoApi = apiKeyService.getSunoApi();
+ SunoApi sunoApi = modelService.getSunoApi();
List musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.1 描述模式
@@ -88,7 +88,7 @@ public class AiMusicServiceImpl implements AiMusicService {
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
// GET 请求,为避免参数过长,分批次处理
- SunoApi sunoApi = apiKeyService.getSunoApi();
+ SunoApi sunoApi = modelService.getSunoApi();
CollUtil.split(streamingTask, 36).forEach(chunkList -> {
Map