94 lines
3.4 KiB
Python
94 lines
3.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
from typing import Any, Optional, Self, Sequence, Callable
|
|
from requests import post as Post, Response
|
|
# from pydantic import BaseModel
|
|
from Interfaces.Application.AnPInterface import AnPInterface
|
|
from Abstracts.AIInterpretersAbstract import AIInterpretersAbstract
|
|
from Abstracts.ModelAbstract import ModelAbstract
|
|
from Models.AIResponseModel import AIResponseModel
|
|
from Utils.Checks import Check
|
|
from Utils.Common import Common
|
|
|
|
class OllamaDriver(AIInterpretersAbstract, ModelAbstract):
    """AI interpreter driver that POSTs prompts to an Ollama-style generate endpoint.

    Settings such as ``url``, ``model``, ``stream``, ``think``, ``format``,
    ``maximum_response_tokens`` and ``maximum_tokens_per_session`` are read as
    instance attributes. They are presumably populated by the base classes from
    ``inputs`` — TODO confirm against ``AIInterpretersAbstract``/``ModelAbstract``.
    """

    def __init__(self:Self,
        anp:AnPInterface,
        key:str,
        inputs:Optional[dict[str, Any|None]|Sequence[Any|None]] = None
    ) -> None:
        """Initialise the driver by forwarding every argument to the base classes.

        Args:
            anp: Application interface. ``request`` later uses ``self.anp`` for
                ``working()`` checks and ``exception()`` reporting (presumably
                stored by the base class — verify).
            key: Forwarded unchanged to the base class initialisers.
            inputs: Optional settings, forwarded unchanged to the base classes.
        """
        super().__init__(anp, key, inputs)

    def request(self:Self,
        session:int|None,
        message:str,
        callback:Optional[Callable[[int|None, AIResponseModel], None]] = None,
        orders:Optional[str|list[str]] = None,
        custom_context:Optional[list[int]] = None
    ) -> tuple[int|None, AIResponseModel]:
        """Send ``message`` to the model and collect the (possibly streamed) reply.

        Args:
            session: Session identifier, resolved through ``get_session``. When
                the model reports ``done``, its returned context is stored back
                into this session via ``save_context``.
            message: Prompt text, sent as the ``"prompt"`` field.
            callback: Optional ``callback(session, results)``, invoked through
                ``Common.execute``. It runs after every intermediate non-empty
                chunk, and once more after the HTTP exchange ends, whatever the
                outcome.
            orders: System instructions, normalised by ``get_orders``. They are
                sent as ``"system"`` only when the normalised value is non-empty.
                ``None`` (the default) behaves like an empty list.
            custom_context: Optional context token list used instead of the
                session's stored context. A copy is sent; the caller's list is
                never mutated.

        Returns:
            ``(session, results)``:

            * ``session`` is the identifier returned by ``get_session``.
            * ``results`` carries ``http_code`` and ``ok``, plus the fields merged
              from the response chunks.
            * For non-200 responses, ``results`` also carries ``http_message``.

        Raises:
            Exceptions from ``get_session``, ``get_orders`` or the HTTP call
            itself (e.g. connection errors) are not caught here. Exceptions
            raised while reading the body of a 200 response are reported via
            ``self.anp.exception`` and leave ``results.ok`` set to ``False``.
        """
        results:AIResponseModel = AIResponseModel()
        context:list[int]
        response:Response

        session, context = self.get_session(session)

        if custom_context:
            # Previously ``custom_context or self.get_session(session)`` tried to
            # unpack the token list itself into (session, context), which fails
            # for any real context. Keep the resolved session; send a private
            # copy of the caller's context instead of the stored one.
            # NOTE(review): confirm get_session is the intended session resolver
            # when a custom context is supplied.
            context = list(custom_context)

        # A ``None`` default avoids one shared mutable list across calls. The
        # default path still hands ``[]`` to get_orders, exactly as before.
        orders = self.get_orders([] if orders is None else orders)

        with Post(self.url, json = self._build_payload(message, orders, context), stream = self.stream) as response:

            results.http_code = response.status_code

            if results.http_code == 200:

                results.ok = True

                try:

                    chunk:bytes

                    for chunk in response.iter_lines():

                        # Stop reading as soon as the application is no longer working.
                        if not self.anp.working():
                            break

                        # Skip empty (keep-alive) lines; each real line is one JSON object.
                        if chunk:

                            results.update(Common.json_decode(chunk))

                            if results.done:
                                # The final chunk carries the context to persist;
                                # the callback for it fires once, after the `with`.
                                self.save_context(session, results.context)
                                break

                            Common.execute(callback, session, results)

                except Exception as exception:
                    self.anp.exception(exception, "anp_ollama_driver_request")
                    results.ok = False

            else:
                # Best effort: the error body may be missing, non-JSON or not an object.
                try:
                    results.http_message = response.json().get("error", "unknown_error")
                except Exception:
                    results.http_message = "unknown_error"

        Common.execute(callback, session, results)

        return session, results

    def _build_payload(self:Self,
        message:str,
        orders:Any,
        context:list[int]
    ) -> dict[str, Any]:
        """Build the JSON request body, omitting optional fields that are empty or unset."""
        options:dict[str, Any] = {}

        if self.maximum_response_tokens is not None:
            options["num_predict"] = self.maximum_response_tokens

        if self.maximum_tokens_per_session is not None:
            options["num_ctx"] = self.maximum_tokens_per_session

        # Only the literal "json" or a dictionary (schema) is forwarded as "format".
        # TODO: support pydantic models via ``self.format.model_json_schema()``.
        send_format:bool = self.format == "json" or Check.is_dictionary(self.format)

        return {
            "model" : self.model,
            "prompt": message,
            **({"system": orders} if len(orders) else {}),
            "stream": self.stream,
            "think" : self.think,
            **({"format" : self.format} if send_format else {}),
            **({"context" : context} if len(context) else {}),
            **({"options" : options} if len(options) else {})
        }