Skip to content

Commit

Permalink
添加对通用文本向量的支持
Browse files Browse the repository at this point in the history
  • Loading branch information
XYWENJIE committed Feb 20, 2024
1 parent 8e033f1 commit dba0f9e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README-zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

## 功能特性

**支持阿里云服务的通义千问和通义万相**
**支持阿里云的(DashScope)灵积服务通义千问和通义万相**

提示:如果要在本地运行通义千问模型,可以使用官方的 Ollama 模块。

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ public record ChatCompletionRequest(
@JsonProperty("parameters") Parameters parameters) {

/**
 * Convenience constructor for a message-based chat request.
 *
 * <p>Builds an {@link Input} that carries only {@code messages} (no prompt, no
 * embedding texts) and a {@link Parameters} instance in which only the first
 * component ({@code "message"}) and the temperature are set; every other
 * parameter is left {@code null}.
 *
 * <p>NOTE(review): the first {@code Parameters} component is presumably the
 * result format; its declaration is not visible from here, so confirm.
 *
 * @param messages    conversation messages to send
 * @param model       model identifier
 * @param temperature sampling temperature, may be {@code null}
 */
public ChatCompletionRequest(List<ChatCompletionMessage> messages, String model, Float temperature) {
    // Exactly one delegating call is allowed; it must match the current
    // three-component Input and ten-component Parameters records.
    this(model,
            new Input(null, messages, null),
            new Parameters("message", null, null, null, null, null, temperature, null, null, null));
}
}

@JsonInclude(Include.NON_NULL)
public record Input(
@JsonProperty("prompt") String prompt,
@JsonProperty("messages") List<ChatCompletionMessage> messages) {
@JsonProperty("messages") List<ChatCompletionMessage> messages,
@JsonProperty("texts")List<String> texts) {

}

Expand All @@ -160,7 +161,12 @@ public record Parameters(
@JsonProperty("repetition_penalty")Float repetitionPenlty,
@JsonProperty("temperature") Float temperature,
@JsonProperty("stop")List<String> stop,
@JsonProperty("incremental_output")Boolean incrementalOutput) {
@JsonProperty("incremental_output")Boolean incrementalOutput,
@JsonProperty("text_type")String textType) {

/**
 * Creates parameters that carry only a text type; the nine preceding
 * components are all left {@code null}.
 *
 * @param textType value serialized as {@code text_type}
 */
public Parameters(String textType) {
    this(
            null, null, null,
            null, null, null,
            null, null, null,
            textType);
}
}

@JsonInclude(Include.NON_NULL)
Expand All @@ -185,7 +191,14 @@ public record Output(
@JsonProperty("task_id") String taskId,
@JsonProperty("task_status") StatusStatus taskStatus,
@JsonProperty("task_metrics") TaskMetrices taskMetrices,
@JsonProperty("results") List<Results> results) {
@JsonProperty("results") List<Results> results,
@JsonProperty("embeddings") List<Embedding> embeddings) {

}

/**
 * One embedding entry of a response's {@code embeddings} list.
 *
 * @param textIndex value of the {@code text_index} JSON field
 * @param embedding value of the {@code embedding} JSON field (the vector components)
 */
public record Embedding(
        @JsonProperty("text_index") Integer textIndex,
        @JsonProperty("embedding") List<Double> embedding) {
}

Expand Down Expand Up @@ -235,6 +248,20 @@ public record QWenImageResponse(

}

/**
 * Request body for the text-embedding call.
 *
 * <p>Null components are left out of the serialized JSON.
 *
 * @param model      model identifier
 * @param input      request input; the convenience constructor sets only its texts
 * @param parameters request parameters; the convenience constructor sets only the text type
 */
@JsonInclude(Include.NON_NULL)
public record EmbeddingRequest(
        String model,
        Input input,
        Parameters parameters) {

    /**
     * Builds a request from raw texts.
     *
     * @param model    model identifier
     * @param texts    texts to embed
     * @param textType text type passed through to {@link Parameters}
     */
    public EmbeddingRequest(String model, List<String> texts, String textType) {
        this(
                model,
                new Input(null, null, texts),
                new Parameters(textType));
    }
}

/**
 * Response body of the text-embedding call.
 *
 * <p>Every component is bound explicitly by JSON name. The service reports the
 * request id in snake case ({@code request_id}), which would not map onto
 * {@code requestId} under Jackson's default naming and would leave it
 * {@code null}.
 *
 * @param requestId value of the {@code request_id} JSON field
 * @param usage     value of the {@code usage} JSON field
 * @param output    value of the {@code output} JSON field
 */
public record EmbeddingResponse(
        @JsonProperty("request_id") String requestId,
        @JsonProperty("usage") Usage usage,
        @JsonProperty("output") Output output) {

}

/**
 * Task status values used by {@code task_status} fields.
 */
public enum StatusStatus {
    PENDING,
    RUNNING,
    SUCCEEDED,
    FAILED,
    UNKNOWN
}
Expand Down Expand Up @@ -265,6 +292,13 @@ public ResponseEntity<QWenImageResponse> findImageTaskResult(String taskId){
return this.restClient.get().uri("/api/v1/tasks/{task_id}",taskId).retrieve().toEntity(QWenImageResponse.class);
}

/**
 * Posts an embedding request to the text-embedding endpoint and returns the
 * response entity with the body deserialized as {@link EmbeddingResponse}.
 *
 * @param embeddingRequest request body to send
 * @return the response entity for the call
 */
public ResponseEntity<EmbeddingResponse> embeddingRequest(EmbeddingRequest embeddingRequest) {
    var responseSpec = this.restClient
            .post()
            .uri("/api/v1/services/embeddings/text-embedding/text-embedding")
            .body(embeddingRequest)
            .retrieve();
    return responseSpec.toEntity(EmbeddingResponse.class);
}

public static void main(String[] args) throws JsonProcessingException {
String body = "{\"output\":{\"finish_reason\":\"stop\",\"text\":\"我是阿里云开发的一款超大规模语言模型,我叫通义千问。\"},\"usage\":{\"total_tokens\":20,\"output_tokens\":17,\"input_tokens\":3},\"request_id\":\"f978f627-fd0f-91fd-be5d-b3ec1ac394b1\"}";
ObjectMapper objectMapper = new ObjectMapper();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
package org.springframework.ai.dashcope.qwen;

import static org.assertj.core.api.Assertions.contentOf;

import java.time.Duration;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.dashcope.DashCopeService;
import org.springframework.ai.dashcope.DashCopeService.DashCopeApiException;
import org.springframework.ai.dashcope.DashCopeService.Usage;
import org.springframework.ai.dashcope.metadata.support.EmbeddingModel;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingClient;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

/**
Expand All @@ -17,8 +29,19 @@
*/
public class QWenEmbeddingClient extends AbstractEmbeddingClient {

private static final Logger logger = LoggerFactory.getLogger(QWenEmbeddingClient.class);

public static final EmbeddingModel DEFAULT_OPENAI_EMBEDDING_MODEL = EmbeddingModel.TEXT_EMBEDDING_V1;

/**
 * Retry policy for embedding calls: up to 10 attempts, retrying only on
 * {@link DashCopeApiException}, with exponential backoff starting at 2 seconds,
 * multiplier 5, capped at 3 minutes. Each failed attempt is logged at WARN.
 */
private final RetryTemplate retryTemplate = RetryTemplate.builder()
        .maxAttempts(10)
        .retryOn(DashCopeApiException.class)
        .exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
        .withListener(new RetryListener() {
            @Override
            public <T extends Object, E extends Throwable> void onError(
                    org.springframework.retry.RetryContext context,
                    org.springframework.retry.RetryCallback<T, E> callback,
                    Throwable throwable) {
                // Parameterized message; the trailing Throwable is logged with its stack trace.
                logger.warn("Retry error. Retry count:{}", context.getRetryCount(), throwable);
            }
        })
        .build();

private final QWenEmbeddingOptions embeddingOptions;

private final DashCopeService dashCopeService;
Expand Down Expand Up @@ -51,8 +74,23 @@ public List<Double> embed(Document document) {

/**
 * Embeds the request's instructions through the DashScope text-embedding
 * endpoint, executed under this client's retry template.
 *
 * <p>The model comes from the configured embedding options and the text type
 * is fixed to {@code "document"}. If the service returns no body, or a body
 * without embeddings, a warning is logged and an empty response is returned
 * instead of failing with a {@link NullPointerException}.
 *
 * @param request embedding request whose instructions are the texts to embed
 * @return the embeddings with their text indexes, plus model/token metadata;
 *         an empty response when the service returned no embeddings
 */
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
    return this.retryTemplate.execute(ctx -> {
        DashCopeService.EmbeddingRequest apiRequest = new DashCopeService.EmbeddingRequest(
                this.embeddingOptions.getModel().value, request.getInstructions(), "document");
        DashCopeService.EmbeddingResponse apiResponse =
                this.dashCopeService.embeddingRequest(apiRequest).getBody();

        // Guard every level that is dereferenced below.
        if (apiResponse == null || apiResponse.output() == null
                || apiResponse.output().embeddings() == null) {
            logger.warn("No embeddings returned for request: {}", apiRequest);
            return new EmbeddingResponse(List.of(), new EmbeddingResponseMetadata());
        }

        var metadata = generateResponseMetadata(apiRequest.model(), apiResponse.usage());
        List<Embedding> embeddings = apiResponse.output().embeddings().stream()
                .map(e -> new Embedding(e.embedding(), e.textIndex()))
                .toList();
        return new EmbeddingResponse(embeddings, metadata);
    });
}

/**
 * Builds the response metadata, storing the model under {@code "model"} and
 * the usage's total token count under {@code "total-tokens"}.
 *
 * @param model model identifier to record
 * @param usage usage information supplying the total token count
 * @return a new metadata instance holding both entries
 */
private EmbeddingResponseMetadata generateResponseMetadata(String model, Usage usage) {
    final EmbeddingResponseMetadata responseMetadata = new EmbeddingResponseMetadata();
    responseMetadata.put("model", model);
    responseMetadata.put("total-tokens", usage.totalTokens());
    return responseMetadata;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,9 @@ public QWenEmbeddingOptions build() {
return this.embeddingOptions;
}
}

/**
 * Returns the embedding model held by these options.
 *
 * @return the value of the {@code model} field
 */
public EmbeddingModel getModel() {
    return this.model;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public ImageResponse call(ImagePrompt imagePrompt) {

String instructions = imagePrompt.getInstructions().get(0).getText();

QWenImageRequest imageRequest = new QWenImageRequest(new Input(instructions,null), null);
QWenImageRequest imageRequest = new QWenImageRequest(new Input(instructions,null,null), null);
ResponseEntity<QWenImageResponse> responseEntity = this.dashCopeService.createQwenImageTask(imageRequest);
return responseEntity.getBody();
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import java.util.List;

import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.dashcope.chat.QwenChatClientIT;
import org.springframework.ai.dashcope.qwen.QWenEmbeddingClient;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.beans.factory.annotation.Autowired;
Expand All @@ -13,6 +16,8 @@
@SpringBootTest
public class EmbeddingIT {

private final Logger logger = LoggerFactory.getLogger(EmbeddingIT.class);

@Autowired
private QWenEmbeddingClient embeddingClient;

Expand All @@ -22,6 +27,7 @@ public void simpleEmbedding() {

EmbeddingResponse embeddingResponse = embeddingClient.embedForResponse(List.of("Hello word"));
assertThat(embeddingResponse.getResults()).hasSize(1);
logger.info("{}",embeddingResponse.getResult().getOutput());
assertThat(embeddingResponse.getResults().get(0)).isNotNull();
//assertThat(embeddingResponse.getMetadata()).containsEntry("model", "");
}
Expand Down

0 comments on commit dba0f9e

Please sign in to comment.