class SiliconFlowEmbedding(BaseEmbedding):
    """SiliconFlow class for embeddings.

    Posts batches of texts to the configured SiliconFlow endpoint
    (``requests`` on the synchronous path, ``aiohttp`` on the asynchronous
    path) and returns one embedding per input text, ordered by the ``index``
    field of the items in the response's ``data`` list.

    Raises:
        ValueError: If ``encoding_format`` is not one of ``VALID_ENCODING``.
    """

    model: str = Field(
        default="BAAI/bge-m3",
        description="""\
            The name of the embedding model to use.
            512 tokens for all models input except `bge-m3` which is 8192.
        """,
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The SiliconFlow API key.",
    )
    base_url: str = Field(
        default=DEFAULT_SILICONFLOW_API_URL,
        description="The base URL for the SiliconFlow API.",
    )
    encoding_format: str = Field(
        default="float",
        description="The format to return the embeddings in. Can be either float or base64.",
    )  # TODO: Consider whether to fix the encoding format as float.
    max_retries: int = Field(
        default=3,
        description="The maximum number of API retries.",
        ge=0,
    )

    # HTTP headers sent with every request; built once in __init__.
    _headers: Any = PrivateAttr()

    def __init__(
        self,
        model: str = "BAAI/bge-m3",
        api_key: Optional[str] = None,
        base_url: str = DEFAULT_SILICONFLOW_API_URL,
        encoding_format: Optional[str] = "float",
        max_retries: int = 3,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Configure the embedding client.

        Args:
            model: Name of the embedding model to request.
            api_key: SiliconFlow API key, sent as a Bearer token.
            base_url: URL the embedding requests are posted to.
            encoding_format: Wire format of the returned embeddings; must be
                one of ``VALID_ENCODING``.
            max_retries: Maximum number of API retries.
            callback_manager: Forwarded to the base class.
            **kwargs: Forwarded to the base class.

        Raises:
            ValueError: If ``encoding_format`` is not in ``VALID_ENCODING``.
        """
        super().__init__(
            model=model,
            api_key=api_key,
            base_url=base_url,
            encoding_format=encoding_format,
            max_retries=max_retries,
            callback_manager=callback_manager,
            **kwargs,
        )
        # Explicit raise instead of `assert`: assertions are stripped when
        # Python runs with -O, which would let an invalid format through.
        if self.encoding_format not in VALID_ENCODING:
            raise ValueError(
                f"Encoding_format parameter {self.encoding_format} not supported. "
                f"Please choose one of {VALID_ENCODING}."
            )
        # NOTE(review): when api_key is None the header value is literally
        # "Bearer None"; confirm whether a missing key should be rejected here.
        self._headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "SiliconFlowEmbedding"

    def _data_formatting(self, response: dict) -> List[List[float]]:
        """Extract the embeddings from a decoded API response.

        Args:
            response: Decoded JSON body containing a ``data`` list whose items
                carry ``index`` and ``embedding`` keys.

        Returns:
            The embeddings sorted by each item's ``index``. When
            ``encoding_format`` is ``"base64"`` each embedding is decoded with
            ``base64_to_float_list``; otherwise it is returned as received.
        """
        results = sorted(response["data"], key=lambda e: e["index"])
        if self.encoding_format == "base64":
            return [base64_to_float_list(data["embedding"]) for data in results]
        return [data["embedding"] for data in results]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self._get_text_embeddings([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        result = await self._aget_text_embeddings([query])
        return result[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    # NOTE(review): presumably retries according to `max_retries`; confirm
    # against the definition of `embedding_retry_decorator`.
    @embedding_retry_decorator
    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts with a synchronous HTTP request.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per text, ordered by the response's ``index`` field.

        Raises:
            RuntimeError: If the decoded response has no ``data`` key; the
                decoded response is passed as the exception argument.
        """
        with requests.Session() as session:
            input_json = {
                "model": self.model,
                "input": texts,
                "encoding_format": self.encoding_format,
            }
            response = session.post(
                self.base_url, json=input_json, headers=self._headers
            ).json()
            if "data" not in response:
                raise RuntimeError(response)
            return self._data_formatting(response)

    @embedding_retry_decorator
    async def _aget_text_embeddings(
        self,
        texts: List[str],
    ) -> List[List[float]]:
        """Embed a batch of texts with an asynchronous HTTP request.

        Args:
            texts: The texts to embed.

        Returns:
            One embedding per text, ordered by the response's ``index`` field.

        Raises:
            aiohttp.ClientResponseError: If the HTTP status is 400 or higher.
            RuntimeError: If the decoded response has no ``data`` key; the
                decoded response is passed as the exception argument.
        """
        async with aiohttp.ClientSession() as session:
            input_json = {
                "input": texts,
                "model": self.model,
                "encoding_format": self.encoding_format,
            }
            async with session.post(
                self.base_url, json=input_json, headers=self._headers
            ) as response:
                # Check the status before decoding so an error response with
                # a non-JSON body reports the HTTP failure rather than a
                # body-decoding error.
                response.raise_for_status()
                response_json = await response.json()
                # Mirror the synchronous path: surface the payload itself
                # instead of a bare KeyError from _data_formatting.
                if "data" not in response_json:
                    raise RuntimeError(response_json)
                return self._data_formatting(response_json)