feat(native): doc loader for common native (#9941)

darkskygit
2025-02-25 07:50:56 +00:00
parent 26674b0cb8
commit 842c39c3be
42 changed files with 2989 additions and 36 deletions


@@ -0,0 +1,169 @@
use std::{io::Cursor, path::PathBuf};
use path_ext::PathExt;
use super::*;
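/// A single chunk of text produced by splitting a loaded document.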
#[derive(Clone, Default)]
pub struct Chunk {
pub index: usize,
pub content: String,
pub start: Option<usize>,
pub end: Option<usize>,
}
pub struct DocOptions {
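/// Line-count threshold below which source files are loaded as a single
/// document instead of being split into per-function/impl documents.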
code_threshold: u64,
}
impl Default for DocOptions {
fn default() -> Self {
Self {
code_threshold: 1000,
}
}
}
pub struct Doc {
pub name: String,
pub chunks: Vec<Chunk>,
}
impl Doc {
pub fn new(file_path: &str, doc: &[u8]) -> Option<Self> {
Self::with_options(file_path, doc, DocOptions::default())
}
pub fn with_options(file_path: &str, doc: &[u8], options: DocOptions) -> Option<Self> {
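// Try magic-byte detection first (falling back to the file name) for binary
// formats; otherwise treat the bytes as UTF-8 or little-endian UTF-16 text.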
if let Some(kind) =
infer::get(&doc[..4096.min(doc.len())]).or(infer::get_from_path(file_path).ok().flatten())
{
if kind.extension() == "pdf" {
return Self::load_pdf(file_path, doc);
} else if kind.extension() == "docx" {
return Self::load_docx(file_path, doc);
} else if kind.extension() == "html" {
return Self::load_html(file_path, doc);
}
} else if let Ok(string) = String::from_utf8(doc.to_vec()).or_else(|_| {
String::from_utf16(
&doc
.chunks_exact(2)
.map(|b| u16::from_le_bytes([b[0], b[1]]))
.collect::<Vec<_>>(),
)
}) {
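// Text content: pick a splitter based on the file extension.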
let path = PathBuf::from(file_path);
match path.ext_str() {
"md" => {
let loader = TextLoader::new(string);
let splitter = MarkdownSplitter::default();
return Self::from_loader(file_path, loader, splitter).ok();
}
"rs" | "c" | "cpp" | "h" | "hpp" | "js" | "ts" | "tsx" | "go" | "py" => {
let name = path.full_str().to_string();
let loader =
SourceCodeLoader::from_string(string).with_parser_option(LanguageParserOptions {
language: get_language_by_filename(&name).ok()?,
parser_threshold: options.code_threshold,
});
let splitter = TokenSplitter::default();
return Self::from_loader(file_path, loader, splitter).ok();
}
_ => {}
}
let loader = TextLoader::new(string);
let splitter = TokenSplitter::default();
return Self::from_loader(file_path, loader, splitter).ok();
}
None
}
fn from_loader(
file_path: &str,
loader: impl Loader,
splitter: impl TextSplitter + 'static,
) -> Result<Doc, LoaderError> {
let name = file_path.to_string();
let chunks = Self::get_chunks_from_loader(loader, splitter)?;
Ok(Self { name, chunks })
}
fn get_chunks_from_loader(
loader: impl Loader,
splitter: impl TextSplitter + 'static,
) -> Result<Vec<Chunk>, LoaderError> {
let docs = loader.load_and_split(splitter)?;
Ok(
docs
.into_iter()
.enumerate()
.map(|(index, d)| Chunk {
index,
content: d.page_content,
..Chunk::default()
})
.collect(),
)
}
fn load_docx(file_path: &str, doc: &[u8]) -> Option<Self> {
let loader = DocxLoader::new(Cursor::new(doc))?;
let splitter = TokenSplitter::default();
Self::from_loader(file_path, loader, splitter).ok()
}
fn load_html(file_path: &str, doc: &[u8]) -> Option<Self> {
let loader = HtmlLoader::from_string(
String::from_utf8(doc.to_vec()).ok()?,
Url::parse(file_path)
.or(Url::parse("https://example.com/"))
.ok()?,
);
let splitter = TokenSplitter::default();
Self::from_loader(file_path, loader, splitter).ok()
}
fn load_pdf(file_path: &str, doc: &[u8]) -> Option<Self> {
let loader = PdfExtractLoader::new(Cursor::new(doc)).ok()?;
let splitter = TokenSplitter::default();
Self::from_loader(file_path, loader, splitter).ok()
}
}
#[cfg(test)]
mod tests {
use std::{
fs::{read, read_to_string},
path::PathBuf,
};
use super::*;
const FIXTURES: [&str; 6] = [
"demo.docx",
"sample.pdf",
"sample.html",
"sample.rs",
"sample.c",
"sample.ts",
];
fn get_fixtures() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("fixtures")
}
#[test]
fn test_fixtures() {
let fixtures = get_fixtures();
for fixture in FIXTURES.iter() {
let buffer = read(fixtures.join(fixture)).unwrap();
let doc = Doc::with_options(fixture, &buffer, DocOptions { code_threshold: 0 }).unwrap();
for chunk in doc.chunks.iter() {
let output =
read_to_string(fixtures.join(format!("{}.{}.md", fixture, chunk.index))).unwrap();
assert_eq!(chunk.content, output);
}
}
}
}


@@ -0,0 +1,71 @@
use docx_parser::MarkdownDocument;
use super::*;
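/// Loads .docx files by converting their contents to Markdown via `docx_parser`.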
#[derive(Debug)]
pub struct DocxLoader {
document: MarkdownDocument,
}
impl DocxLoader {
pub fn new<R: Read + Seek>(reader: R) -> Option<Self> {
Some(Self {
document: MarkdownDocument::from_reader(reader)?,
})
}
fn extract_text(&self) -> String {
self.document.to_markdown(false)
}
fn extract_text_to_doc(&self) -> Document {
Document::new(self.extract_text())
}
}
impl Loader for DocxLoader {
fn load(self) -> Result<Vec<Document>, LoaderError> {
let doc = self.extract_text_to_doc();
Ok(vec![doc])
}
}
#[cfg(test)]
mod tests {
use std::{fs::read, io::Cursor, path::PathBuf};
use super::*;
fn get_fixtures_path() -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("fixtures")
}
#[test]
fn test_parse_docx() {
let docx_buffer = include_bytes!("../../../fixtures/demo.docx");
let parsed_buffer = include_str!("../../../fixtures/demo.docx.md");
{
let loader = DocxLoader::new(Cursor::new(docx_buffer)).unwrap();
let documents = loader.load().unwrap();
assert_eq!(documents.len(), 1);
assert_eq!(documents[0].page_content, parsed_buffer);
}
{
let loader = DocxLoader::new(Cursor::new(docx_buffer)).unwrap();
let documents = loader.load_and_split(TokenSplitter::default()).unwrap();
for (idx, doc) in documents.into_iter().enumerate() {
assert_eq!(
doc.page_content,
String::from_utf8_lossy(
&read(get_fixtures_path().join(format!("demo.docx.{}.md", idx))).unwrap()
)
);
}
}
}
}


@@ -0,0 +1,42 @@
use std::{io, str::Utf8Error, string::FromUtf8Error};
use thiserror::Error;
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
use super::*;
#[derive(Error, Debug)]
pub enum LoaderError {
#[error("{0}")]
TextSplitterError(#[from] TextSplitterError),
#[error(transparent)]
IOError(#[from] io::Error),
#[error(transparent)]
Utf8Error(#[from] Utf8Error),
#[error(transparent)]
FromUtf8Error(#[from] FromUtf8Error),
#[cfg(feature = "pdf-extract")]
#[error(transparent)]
PdfExtractError(#[from] pdf_extract::Error),
#[cfg(feature = "pdf-extract")]
#[error(transparent)]
PdfExtractOutputError(#[from] pdf_extract::OutputError),
#[error(transparent)]
ReadabilityError(#[from] readability::error::Error),
#[error("Unsupported source language")]
UnsupportedLanguage,
#[error("Error: {0}")]
OtherError(String),
}
pub type LoaderResult<T> = Result<T, LoaderError>;


@@ -0,0 +1,87 @@
use std::{collections::HashMap, io::Cursor};
use serde_json::Value;
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
use super::*;
#[derive(Debug, Clone)]
pub struct HtmlLoader<R> {
html: R,
url: Url,
}
impl HtmlLoader<Cursor<Vec<u8>>> {
pub fn from_string<S: Into<String>>(input: S, url: Url) -> Self {
let input = input.into();
let reader = Cursor::new(input.into_bytes());
Self::new(reader, url)
}
}
impl<R: Read> HtmlLoader<R> {
pub fn new(html: R, url: Url) -> Self {
Self { html, url }
}
}
impl<R: Read + Send + Sync + 'static> Loader for HtmlLoader<R> {
fn load(mut self) -> Result<Vec<Document>, LoaderError> {
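// Run readability extraction to strip page boilerplate, keeping the title and main text.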
let cleaned_html = readability::extractor::extract(&mut self.html, &self.url)?;
let doc =
Document::new(format!("{}\n{}", cleaned_html.title, cleaned_html.text)).with_metadata(
HashMap::from([("source".to_string(), Value::from(self.url.as_str()))]),
);
Ok(vec![doc])
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_html_loader() {
let input = "<p>Hello world!</p>";
let html_loader = HtmlLoader::new(
input.as_bytes(),
Url::parse("https://example.com/").unwrap(),
);
let documents = html_loader.load().unwrap();
let expected = "\nHello world!";
assert_eq!(documents.len(), 1);
assert_eq!(
documents[0].metadata.get("source").unwrap(),
&Value::from("https://example.com/")
);
assert_eq!(documents[0].page_content, expected);
}
#[test]
fn test_html_load_from_path() {
let buffer = include_bytes!("../../../fixtures/sample.html");
let html_loader = HtmlLoader::new(
Cursor::new(buffer),
Url::parse("https://example.com/").unwrap(),
);
let documents = html_loader.load().unwrap();
let expected = "Example Domain\n\n This domain is for use in illustrative examples in \
documents. You may\n use this domain in literature without prior \
coordination or asking for\n permission.\n More information...";
assert_eq!(documents.len(), 1);
assert_eq!(
documents[0].metadata.get("source").unwrap(),
&Value::from("https://example.com/")
);
assert_eq!(documents[0].page_content, expected);
}
}


@@ -0,0 +1,33 @@
mod docx;
mod error;
mod html;
mod pdf;
mod source;
mod text;
use std::io::{Read, Seek};
use super::*;
// modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
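/// A document loader that produces `Document`s and can optionally split them
/// with a `TextSplitter`.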
pub trait Loader: Send + Sync {
fn load(self) -> Result<Vec<Document>, LoaderError>;
fn load_and_split<TS: TextSplitter + 'static>(
self,
splitter: TS,
) -> Result<Vec<Document>, LoaderError>
where
Self: Sized,
{
let docs = self.load()?;
Ok(splitter.split_documents(&docs)?)
}
}
pub use docx::DocxLoader;
pub use error::{LoaderError, LoaderResult};
pub use html::HtmlLoader;
pub use pdf::PdfExtractLoader;
pub use source::{get_language_by_filename, LanguageParserOptions, SourceCodeLoader};
pub use text::TextLoader;
pub use url::Url;


@@ -0,0 +1,70 @@
use pdf_extract::{output_doc, output_doc_encrypted, PlainTextOutput};
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
use super::*;
#[derive(Debug, Clone)]
pub struct PdfExtractLoader {
document: pdf_extract::Document,
}
impl PdfExtractLoader {
pub fn new<R: Read>(reader: R) -> Result<Self, LoaderError> {
let document = pdf_extract::Document::load_from(reader)
.map_err(|e| LoaderError::OtherError(e.to_string()))?;
Ok(Self { document })
}
}
impl PdfExtractLoader {
fn extract_text(&self) -> Result<String, LoaderError> {
let mut doc = self.document.clone();
let mut buffer: Vec<u8> = Vec::new();
let mut output = PlainTextOutput::new(&mut buffer as &mut dyn std::io::Write);
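// Encrypted documents are opened with an empty password, which covers PDFs
// that are encrypted but not password-protected.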
if doc.is_encrypted() {
output_doc_encrypted(&mut doc, &mut output, "")?;
} else {
output_doc(&doc, &mut output)?;
}
Ok(String::from_utf8(buffer)?)
}
fn extract_text_to_doc(&self) -> Result<Document, LoaderError> {
let text = self.extract_text()?;
Ok(Document::new(text))
}
}
impl Loader for PdfExtractLoader {
fn load(self) -> Result<Vec<Document>, LoaderError> {
let doc = self.extract_text_to_doc()?;
Ok(vec![doc])
}
}
#[cfg(test)]
mod tests {
use std::{fs::read, io::Cursor, path::PathBuf};
use super::*;
#[test]
fn test_parse_pdf() {
let fixtures = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("fixtures");
let buffer = read(fixtures.join("sample.pdf")).unwrap();
let reader = Cursor::new(buffer);
let loader = PdfExtractLoader::new(reader).expect("Failed to create PdfExtractLoader");
let docs = loader.load().unwrap();
assert_eq!(docs.len(), 1);
assert_eq!(
&docs[0].page_content[..100],
"\n\nSample PDF\nThis is a simple PDF file. Fun fun fun.\n\nLorem ipsum dolor sit amet, \
consectetuer a"
);
}
}


@@ -0,0 +1,61 @@
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
mod parser;
pub use parser::{get_language_by_filename, LanguageParser, LanguageParserOptions};
use super::*;
#[derive(Debug, Clone)]
pub struct SourceCodeLoader {
content: String,
parser_option: LanguageParserOptions,
}
impl SourceCodeLoader {
pub fn from_string<S: Into<String>>(input: S) -> Self {
Self {
content: input.into(),
parser_option: LanguageParserOptions::default(),
}
}
}
impl SourceCodeLoader {
pub fn with_parser_option(mut self, parser_option: LanguageParserOptions) -> Self {
self.parser_option = parser_option;
self
}
}
impl Loader for SourceCodeLoader {
fn load(self) -> Result<Vec<Document>, LoaderError> {
let options = self.parser_option.clone();
let docs = LanguageParser::from_language(options.language)
.with_parser_threshold(options.parser_threshold)
.parse_code(&self.content)?;
Ok(docs)
}
}
#[cfg(test)]
mod tests {
use parser::Language;
use super::*;
#[test]
fn test_source_code_loader() {
let content = include_str!("../../../../fixtures/sample.rs");
let loader = SourceCodeLoader::from_string(content).with_parser_option(LanguageParserOptions {
language: Language::Rust,
..Default::default()
});
let documents_with_content = loader.load().unwrap();
assert_eq!(documents_with_content.len(), 1);
}
}


@@ -0,0 +1,246 @@
use std::{collections::HashMap, fmt::Debug, string::ToString};
use strum_macros::Display;
use tree_sitter::{Parser, Tree};
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
use super::*;
#[derive(Display, Debug, Clone)]
pub enum Language {
Rust,
C,
Cpp,
Javascript,
Typescript,
Go,
Python,
}
pub enum LanguageContentTypes {
SimplifiedCode,
FunctionsImpls,
}
impl std::fmt::Display for LanguageContentTypes {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
match self {
LanguageContentTypes::SimplifiedCode => "simplified_code",
LanguageContentTypes::FunctionsImpls => "functions_impls",
}
)
}
}
#[derive(Debug, Clone)]
pub struct LanguageParserOptions {
pub parser_threshold: u64,
pub language: Language,
}
impl Default for LanguageParserOptions {
fn default() -> Self {
Self {
parser_threshold: 1000,
language: Language::Rust,
}
}
}
pub struct LanguageParser {
parser: Parser,
parser_options: LanguageParserOptions,
}
impl Debug for LanguageParser {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"LanguageParser {{ language: {:?} }}",
self.parser_options.language
)
}
}
impl Clone for LanguageParser {
fn clone(&self) -> Self {
LanguageParser {
parser: get_language_parser(&self.parser_options.language),
parser_options: self.parser_options.clone(),
}
}
}
pub fn get_language_by_filename(name: &str) -> LoaderResult<Language> {
let extension = name
.split('.')
.last()
.ok_or(LoaderError::UnsupportedLanguage)?;
let language = match extension.to_lowercase().as_str() {
"rs" => Language::Rust,
"c" => Language::C,
"cpp" => Language::Cpp,
"h" => Language::C,
"hpp" => Language::Cpp,
"js" => Language::Javascript,
"ts" => Language::Typescript,
"tsx" => Language::Typescript,
"go" => Language::Go,
"py" => Language::Python,
_ => return Err(LoaderError::UnsupportedLanguage),
};
Ok(language)
}
fn get_language_parser(language: &Language) -> Parser {
let mut parser = Parser::new();
let lang = match language {
Language::Rust => tree_sitter_rust::LANGUAGE,
Language::C => tree_sitter_c::LANGUAGE,
Language::Cpp => tree_sitter_cpp::LANGUAGE,
Language::Javascript => tree_sitter_javascript::LANGUAGE,
Language::Typescript => tree_sitter_typescript::LANGUAGE_TSX,
Language::Go => tree_sitter_go::LANGUAGE,
Language::Python => tree_sitter_python::LANGUAGE,
};
parser
.set_language(&lang.into())
.unwrap_or_else(|_| panic!("Error loading grammar for language: {:?}", language));
parser
}
impl LanguageParser {
pub fn from_language(language: Language) -> Self {
Self {
parser: get_language_parser(&language),
parser_options: LanguageParserOptions {
language,
..LanguageParserOptions::default()
},
}
}
pub fn with_parser_threshold(mut self, threshold: u64) -> Self {
self.parser_options.parser_threshold = threshold;
self
}
}
impl LanguageParser {
pub fn parse_code(&mut self, code: &String) -> LoaderResult<Vec<Document>> {
let tree = self
.parser
.parse(code, None)
.ok_or(LoaderError::UnsupportedLanguage)?;
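// Files shorter (in lines) than the parser threshold are returned whole as "simplified_code".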
if self.parser_options.parser_threshold > tree.root_node().end_position().row as u64 {
return Ok(vec![Document::new(code).with_metadata(HashMap::from([
(
"content_type".to_string(),
serde_json::Value::from(LanguageContentTypes::SimplifiedCode.to_string()),
),
(
"language".to_string(),
serde_json::Value::from(self.parser_options.language.to_string()),
),
]))]);
}
self.extract_functions_classes(tree, code)
}
pub fn extract_functions_classes(
&self,
tree: Tree,
code: &String,
) -> LoaderResult<Vec<Document>> {
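// Emit one document per top-level syntax node, tagging function/impl nodes as
// "functions_impls" and everything else as "simplified_code".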
let mut chunks = Vec::new();
let count = tree.root_node().child_count();
for i in 0..count {
let Some(node) = tree.root_node().child(i) else {
continue;
};
let source_code = node.utf8_text(code.as_bytes())?.to_string();
let lang_meta = (
"language".to_string(),
serde_json::Value::from(self.parser_options.language.to_string()),
);
if node.kind() == "function_item" || node.kind() == "impl_item" {
let doc = Document::new(source_code).with_metadata(HashMap::from([
lang_meta.clone(),
(
"content_type".to_string(),
serde_json::Value::from(LanguageContentTypes::FunctionsImpls.to_string()),
),
]));
chunks.push(doc);
} else {
let doc = Document::new(source_code).with_metadata(HashMap::from([
lang_meta.clone(),
(
"content_type".to_string(),
serde_json::Value::from(LanguageContentTypes::SimplifiedCode.to_string()),
),
]));
chunks.push(doc);
}
}
Ok(chunks)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_code_parser() {
let code = r#"
fn main() {
println!("Hello, world!");
}
pub struct Person {
name: String,
age: i32,
}
impl Person {
pub fn new(name: String, age: i32) -> Self {
Self { name, age }
}
pub fn get_name(&self) -> &str {
&self.name
}
pub fn get_age(&self) -> i32 {
self.age
}
}
"#;
let mut parser = LanguageParser::from_language(Language::Rust);
let documents = parser.parse_code(&code.to_string()).unwrap();
assert_eq!(documents.len(), 1);
// Set the parser threshold to 10 for testing
let mut parser = parser.with_parser_threshold(10);
let documents = parser.parse_code(&code.to_string()).unwrap();
assert_eq!(documents.len(), 3);
assert_eq!(
documents[0].page_content,
"fn main() {\n println!(\"Hello, world!\");\n }"
);
assert_eq!(
documents[1].metadata.get("content_type").unwrap(),
LanguageContentTypes::SimplifiedCode.to_string().as_str()
);
}
}


@@ -0,0 +1,24 @@
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/document_loaders
*/
use super::*;
#[derive(Debug, Clone)]
pub struct TextLoader {
content: String,
}
impl TextLoader {
pub fn new<T: Into<String>>(input: T) -> Self {
Self {
content: input.into(),
}
}
}
impl Loader for TextLoader {
fn load(self) -> Result<Vec<Document>, LoaderError> {
let doc = Document::new(self.content);
Ok(vec![doc])
}
}


@@ -0,0 +1,12 @@
mod document;
mod loader;
mod splitter;
mod types;
pub use document::{Chunk, Doc};
use loader::{
get_language_by_filename, DocxLoader, HtmlLoader, LanguageParserOptions, Loader, LoaderError,
PdfExtractLoader, SourceCodeLoader, TextLoader, Url,
};
use splitter::{MarkdownSplitter, TextSplitter, TextSplitterError, TokenSplitter};
use types::Document;


@@ -0,0 +1,35 @@
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/text_splitter
*/
use text_splitter::ChunkConfigError;
use thiserror::Error;
#[derive(Error, Debug)]
pub enum TextSplitterError {
#[error("Empty input text")]
EmptyInputText,
#[error("Mismatch metadata and text")]
MetadataTextMismatch,
#[error("Tokenizer not found")]
TokenizerNotFound,
#[error("Tokenizer creation failed due to invalid tokenizer")]
InvalidTokenizer,
#[error("Tokenizer creation failed due to invalid model")]
InvalidModel,
#[error("Invalid chunk overlap and size")]
InvalidSplitterOptions,
#[error("Error: {0}")]
OtherError(String),
}
impl From<ChunkConfigError> for TextSplitterError {
fn from(_: ChunkConfigError) -> Self {
Self::InvalidSplitterOptions
}
}


@@ -0,0 +1,36 @@
use text_splitter::ChunkConfig;
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/text_splitter
*/
use super::*;
pub struct MarkdownSplitter {
splitter_options: SplitterOptions,
}
impl Default for MarkdownSplitter {
fn default() -> Self {
MarkdownSplitter::new(SplitterOptions::default())
}
}
impl MarkdownSplitter {
pub fn new(options: SplitterOptions) -> MarkdownSplitter {
MarkdownSplitter {
splitter_options: options,
}
}
}
impl TextSplitter for MarkdownSplitter {
fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
Ok(
text_splitter::MarkdownSplitter::new(chunk_config)
.chunks(text)
.map(|x| x.to_string())
.collect(),
)
}
}


@@ -0,0 +1,58 @@
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/text_splitter
*/
mod error;
mod markdown;
mod options;
mod token;
use std::collections::HashMap;
pub use error::TextSplitterError;
pub use markdown::MarkdownSplitter;
use options::SplitterOptions;
use serde_json::Value;
pub use token::TokenSplitter;
use super::*;
pub trait TextSplitter: Send + Sync {
fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError>;
fn split_documents(&self, documents: &[Document]) -> Result<Vec<Document>, TextSplitterError> {
let mut texts: Vec<String> = Vec::new();
let mut metadatas: Vec<HashMap<String, Value>> = Vec::new();
documents.iter().for_each(|d| {
texts.push(d.page_content.clone());
metadatas.push(d.metadata.clone());
});
self.create_documents(&texts, &metadatas)
}
fn create_documents(
&self,
text: &[String],
metadatas: &[HashMap<String, Value>],
) -> Result<Vec<Document>, TextSplitterError> {
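// Pair each text with its metadata (empty maps if none were given), then split
// every text and carry its metadata onto each resulting chunk.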
let mut metadatas = metadatas.to_vec();
if metadatas.is_empty() {
metadatas = vec![HashMap::new(); text.len()];
}
if text.len() != metadatas.len() {
return Err(TextSplitterError::MetadataTextMismatch);
}
let mut documents: Vec<Document> = Vec::new();
for i in 0..text.len() {
let chunks = self.split_text(&text[i])?;
for chunk in chunks {
let document = Document::new(chunk).with_metadata(metadatas[i].clone());
documents.push(document);
}
}
Ok(documents)
}
}


@@ -0,0 +1,96 @@
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/text_splitter
*/
use text_splitter::ChunkConfig;
use tiktoken_rs::{get_bpe_from_model, get_bpe_from_tokenizer, tokenizer::Tokenizer, CoreBPE};
use super::TextSplitterError;
// SplitterOptions holds the configuration for a text splitter.
#[derive(Debug, Clone)]
pub struct SplitterOptions {
pub chunk_size: usize,
pub chunk_overlap: usize,
pub model_name: String,
pub encoding_name: String,
pub trim_chunks: bool,
}
impl Default for SplitterOptions {
fn default() -> Self {
Self::new()
}
}
impl SplitterOptions {
pub fn new() -> Self {
SplitterOptions {
chunk_size: 512,
chunk_overlap: 0,
model_name: String::from("gpt-3.5-turbo"),
encoding_name: String::from("cl100k_base"),
trim_chunks: false,
}
}
}
// Builder-style setters for SplitterOptions.
impl SplitterOptions {
pub fn with_chunk_size(mut self, chunk_size: usize) -> Self {
self.chunk_size = chunk_size;
self
}
pub fn with_chunk_overlap(mut self, chunk_overlap: usize) -> Self {
self.chunk_overlap = chunk_overlap;
self
}
pub fn with_model_name(mut self, model_name: &str) -> Self {
self.model_name = String::from(model_name);
self
}
pub fn with_encoding_name(mut self, encoding_name: &str) -> Self {
self.encoding_name = String::from(encoding_name);
self
}
pub fn with_trim_chunks(mut self, trim_chunks: bool) -> Self {
self.trim_chunks = trim_chunks;
self
}
pub fn get_tokenizer_from_str(s: &str) -> Option<Tokenizer> {
match s.to_lowercase().as_str() {
"cl100k_base" => Some(Tokenizer::Cl100kBase),
"p50k_base" => Some(Tokenizer::P50kBase),
"r50k_base" => Some(Tokenizer::R50kBase),
"p50k_edit" => Some(Tokenizer::P50kEdit),
"gpt2" => Some(Tokenizer::Gpt2),
_ => None,
}
}
}
impl TryFrom<&SplitterOptions> for ChunkConfig<CoreBPE> {
type Error = TextSplitterError;
fn try_from(options: &SplitterOptions) -> Result<Self, Self::Error> {
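// Prefer the explicit encoding name; fall back to deriving the tokenizer from the model name.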
let tk = if !options.encoding_name.is_empty() {
let tokenizer = SplitterOptions::get_tokenizer_from_str(&options.encoding_name)
.ok_or(TextSplitterError::TokenizerNotFound)?;
get_bpe_from_tokenizer(tokenizer).map_err(|_| TextSplitterError::InvalidTokenizer)?
} else {
get_bpe_from_model(&options.model_name).map_err(|_| TextSplitterError::InvalidModel)?
};
Ok(
ChunkConfig::new(options.chunk_size)
.with_sizer(tk)
.with_trim(options.trim_chunks)
.with_overlap(options.chunk_overlap)?,
)
}
}


@@ -0,0 +1,37 @@
use text_splitter::ChunkConfig;
/**
* modified from https://github.com/Abraxas-365/langchain-rust/tree/v4.6.0/src/text_splitter
*/
use super::*;
#[derive(Debug, Clone)]
pub struct TokenSplitter {
splitter_options: SplitterOptions,
}
impl Default for TokenSplitter {
fn default() -> Self {
TokenSplitter::new(SplitterOptions::default())
}
}
impl TokenSplitter {
pub fn new(options: SplitterOptions) -> TokenSplitter {
TokenSplitter {
splitter_options: options,
}
}
}
impl TextSplitter for TokenSplitter {
fn split_text(&self, text: &str) -> Result<Vec<String>, TextSplitterError> {
let chunk_config = ChunkConfig::try_from(&self.splitter_options)?;
Ok(
text_splitter::TextSplitter::new(chunk_config)
.chunks(text)
.map(|x| x.to_string())
.collect(),
)
}
}


@@ -0,0 +1,37 @@
use std::collections::HashMap;
use serde_json::Value;
#[derive(Debug, Clone)]
pub struct Document {
pub page_content: String,
pub metadata: HashMap<String, Value>,
}
impl Document {
/// Constructs a new `Document` with the provided `page_content` and an empty
/// `metadata` map.
pub fn new<S: Into<String>>(page_content: S) -> Self {
Document {
page_content: page_content.into(),
metadata: HashMap::new(),
}
}
/// Sets the `metadata` Map of the `Document` to the provided HashMap.
pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
self.metadata = metadata;
self
}
}
impl Default for Document {
/// Provides a default `Document` with empty `page_content` and an empty
/// `metadata` map.
fn default() -> Self {
Document {
page_content: "".to_string(),
metadata: HashMap::new(),
}
}
}


@@ -1 +1,3 @@
#[cfg(feature = "doc-loader")]
pub mod doc_loader;
pub mod hashcash;