new db and readability

This commit is contained in:
MatthieuCoder 2023-01-22 01:55:20 +04:00
parent acdcd9bbcb
commit 9f798c57f8
9 changed files with 142 additions and 273 deletions

View file

@ -1,6 +1,5 @@
use std::fs;
use autofeur::french_ipa::parse_word;
use autofeur::save::Save;
use kdam::tqdm;
@ -42,18 +41,10 @@ async fn main() {
phonems.append(&mut pron);
}
let mut invalid = 0;
for phoneme in tqdm!(phonems.iter()) {
match parse_word(&phoneme) {
Some(a) => save.trie.insert(a),
None => {
invalid += 1;
}
}
save.trie.insert(&phoneme);
}
println!("Invalid items count: {}", invalid);
fs::write("assets/db.bin", bincode::serialize(&save).unwrap()).unwrap();
println!("Generated to assets/db.bin");

View file

@ -17,28 +17,47 @@ fn parse_query(query: &str) -> HashMap<String, String> {
.collect()
}
async fn handler(request: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
/// Turns an `anyhow::Error` into an HTTP response with status 400.
///
/// The response body is the textual form of the error's root cause, so any
/// context layered on top of the original error is not sent to the client.
fn anyhow_response(err: anyhow::Error) -> Response<Body> {
    let message = err.root_cause().to_string();
    let payload = Body::from(message);
    Response::builder().status(400).body(payload).unwrap()
}
async fn handler(request: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let save: &Arc<Save> = request.extensions().get().unwrap();
let query = request
let query = match request
.uri()
.query()
.ok_or_else(|| anyhow!("query does not exists"))?;
let data = parse_query(query)
.ok_or_else(|| anyhow_response(anyhow!("query does not exists")))
{
Ok(ok) => ok,
Err(err) => return Ok(err),
};
let data = match parse_query(query)
.get("grapheme")
.ok_or_else(|| anyhow!("grapheme argument is not specified"))?
.clone();
.ok_or_else(|| anyhow_response(anyhow!("grapheme argument is not specified")))
{
Ok(ok) => ok.clone(),
Err(err) => return Ok(err),
};
let infered = save
let infered = match save
.inference(&data)
.await
.or_else(|_| Err(anyhow!("cannot find data")))?;
.or_else(|e| Err(anyhow_response(e.context("inference error"))))
{
Ok(ok) => ok,
Err(e) => return Ok(e),
};
Ok(Response::builder().body(Body::from(infered)).unwrap())
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let checkpoint: Save = bincode::deserialize(&fs::read("assets/db.bin").unwrap()).unwrap();
let data = Box::leak(Box::new(fs::read("assets/db.bin").unwrap()));
let checkpoint: Save = bincode::deserialize(data).unwrap();
let service = ServiceBuilder::new()
.layer(AddExtensionLayer::new(Arc::new(checkpoint)))
// Wrap a `Service` in our middleware stack

View file

@ -1,129 +0,0 @@
use std::hash::Hash;
use unicode_segmentation::UnicodeSegmentation;
// Expands to a chain of comparisons that maps a string to the position of the
// first equal literal in the list, wrapped as `Some(Self(position))`.
// Must be invoked where `Self` is a tuple struct wrapping the index.
// Invocation shape: `ipa_element_to_number!(@step 0usize, needle, "a", "b",)`.
macro_rules! ipa_element_to_number {
    // No literals left to compare: the input matched none of them.
    (@step $_pos:expr, $needle:ident,) => {
        None
    };
    // Compare against the first remaining literal; on mismatch recurse on the
    // rest of the list with the position advanced by one.
    (@step $pos:expr, $needle:ident, $first:literal, $($rest:literal,)*) => {
        if $needle != $first {
            ipa_element_to_number!(@step $pos + 1usize, $needle, $($rest,)*)
        } else {
            Some(Self($pos))
        }
    };
}
// Expands to a chain of comparisons that maps a numeric index back to the
// literal at that position in the list.
// Invocation shape: `ipa_number_to_ipa!(@step 0usize, index, "a", "b",)`.
macro_rules! ipa_number_to_ipa {
    // List exhausted: the index is out of range. Note that this yields the
    // string literal below; it does not invoke the `unreachable!` macro.
    (@step $_pos:expr, $value:ident,) => {
        "unreachable!()"
    };
    // Yield the first remaining literal when the index equals the current
    // position; otherwise recurse on the rest with the position advanced.
    (@step $pos:expr, $value:ident, $first:literal, $($rest:literal,)*) => {
        if $value != $pos {
            ipa_number_to_ipa!(@step $pos + 1usize, $value, $($rest,)*)
        } else {
            $first
        }
    };
}
// Discards one token tree and expands to the expression that follows it.
// Used to substitute a fixed expression for each element of a repetition.
macro_rules! replace_expr {
    ($_discarded:tt $replacement:expr) => {
        $replacement
    };
}
// Counts the token trees passed to it by expanding to `0usize + 1usize + ...`,
// one `1usize` per token tree. Every token tree counts, including separators
// such as commas.
macro_rules! count_tts {
    ($($token:tt)*) => {
        0usize $(+ replace_expr!($token 1usize))*
    };
}
// Declares a `Copy` newtype `$name(pub usize)` whose value is an index into the
// given list of string literals, together with conversions in both directions.
//
// Generated items:
// - `SIZE`: the number of literals in the list.
// - `from_char`: index of the first literal equal to the input, or `None`.
// - `to_char`: the literal stored at the wrapped index; an out-of-range index
//   yields the fallback string produced by `ipa_number_to_ipa!`.
macro_rules! ipa_map {
    ($name:ident, $($l:literal),*) => {
        use serde::{Deserialize, Serialize};

        #[derive(Eq, Hash, PartialEq, Debug, Copy, Clone, Serialize, Deserialize)]
        pub struct $name(pub usize);

        impl $name {
            // The literals are forwarded WITHOUT separators: `count_tts!`
            // counts every token tree, so passing `$($l,)*` would count the
            // commas as well and report twice the real number of symbols.
            pub const SIZE: usize = count_tts!($($l)*);

            pub fn from_char(ch: &str) -> Option<$name> {
                ipa_element_to_number!(@step 0usize, ch, $($l,)*)
            }

            pub fn to_char(self) -> &'static str {
                let num = self.0;
                ipa_number_to_ipa!(@step 0usize, num, $($l,)*)
            }
        }
    };
}
// Symbol table for `FrenchIPAChar`. The position of each literal in this list
// is its numeric index (starting at 0), so reordering or removing an entry
// changes the index of every entry after it.
ipa_map!(
FrenchIPAChar,
"a",
"ɑ",
"ɑ̃",
"e",
"ɛ",
"ɛ̃",
"ə",
"i",
"o",
"ɔ",
"ɔ̃",
"œ",
"œ̃",
"ø",
"u",
"y",
"j",
"ɥ",
"w",
"b",
"d",
"f",
"g",
"k",
"l",
"m",
"n",
"ɲ",
"ŋ",
"p",
"ʁ",
"s",
"ʃ",
"t",
"v",
"z",
"ʒ",
// NOTE(review): this looks identical to the earlier "g" entry. `from_char`
// returns the first matching literal, so this index would never be produced
// by parsing. Confirm before removing it: the indices of the entries below
// would shift.
"g",
"ɡ",
"ɪ",
"ʊ",
"x",
"r"
);
/// A word represented as a sequence of `FrenchIPAChar` symbols.
pub type FrenchIPAWord = Vec<FrenchIPAChar>;
/// Converts a string into a `FrenchIPAWord`, one symbol per extended grapheme
/// cluster.
///
/// Returns `None` as soon as a grapheme has no `FrenchIPAChar` equivalent; the
/// offending grapheme is printed to stdout before returning.
pub fn parse_word(str: &str) -> Option<FrenchIPAWord> {
    str.graphemes(true)
        .map(|symbol| {
            let parsed = FrenchIPAChar::from_char(symbol);
            if parsed.is_none() {
                println!("invalid char: {}", symbol);
            }
            parsed
        })
        // Collecting into `Option<Vec<_>>` stops at the first `None`, so later
        // graphemes are neither parsed nor reported.
        .collect()
}

View file

@ -5,11 +5,11 @@ use itertools::Itertools;
use levenshtein::levenshtein;
use unicode_segmentation::UnicodeSegmentation;
use crate::{french_ipa::parse_word, save::Save};
use crate::save::Save;
async fn call_inference_service(word: &str) -> anyhow::Result<String> {
let server: Result<String, anyhow::Error> =
env::var("PHONEMIZER").or_else(|_| Ok("".to_string()));
env::var("PHONEMIZER").or_else(|_| Ok("http://localhost:8000/".to_string()));
Ok(
reqwest::get(format!("{}?grapheme={}", server.unwrap(), word))
.await?
@ -18,22 +18,23 @@ async fn call_inference_service(word: &str) -> anyhow::Result<String> {
)
}
impl Save {
impl Save<'_> {
pub async fn inference(&self, prefix: &str) -> anyhow::Result<String> {
let phonemes = call_inference_service(prefix).await?;
let ipa_phonemes =
parse_word(&phonemes).ok_or_else(|| anyhow!("failed to parse the word"))?;
let completion = self
.trie
.random_starting_with(ipa_phonemes)
.random_starting_with(&phonemes)
.ok_or_else(|| anyhow!("no matches"))?;
let infered = phonemes.add(&completion);
let word = self
.reverse_index
.get(&infered)
.ok_or_else(|| anyhow!("matched values is not in dictionary"))?;
let infered = phonemes.clone().add(&completion);
let word = self.reverse_index.get(&infered).ok_or_else(|| {
anyhow!(
"matched value is not in dictionary {} {}",
infered,
phonemes
)
})?;
println!("Matching {} by adding {}", word, completion);

View file

@ -1,4 +1,3 @@
pub mod trie;
pub mod french_ipa;
pub mod save;
pub mod inference;

View file

@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize};
use crate::trie::Trie;
#[derive(Debug, Deserialize, Serialize, Default)]
pub struct Save {
pub trie: Trie,
pub reverse_index: HashMap<String, String>
}
pub struct Save<'a> {
#[serde(borrow = "'a")]
pub trie: Trie<'a>,
pub reverse_index: HashMap<String, String>,
}

View file

@ -1,169 +1,150 @@
use std::collections::HashMap;
use rand::{thread_rng, Rng};
use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
use serde::{Deserialize, Serialize};
use crate::french_ipa::{FrenchIPAChar, FrenchIPAWord};
use unicode_segmentation::UnicodeSegmentation;
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct TrieNode {
value: Option<FrenchIPAChar>,
pub struct TrieNode<'a> {
is_final: bool,
child_nodes: HashMap<FrenchIPAChar, TrieNode>,
#[serde(borrow = "'a")]
child_nodes: HashMap<&'a str, TrieNode<'a>>,
child_count: u64,
}
impl TrieNode {
impl<'a> TrieNode<'a> {
// Create new node
pub fn new(c: FrenchIPAChar, is_final: bool) -> TrieNode {
pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
TrieNode {
value: Option::Some(c),
is_final,
child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
child_nodes: HashMap::new(),
child_count: 0,
}
}
pub fn new_root() -> TrieNode {
pub fn new_root<'b>() -> TrieNode<'b> {
TrieNode {
value: Option::None,
is_final: false,
child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
child_nodes: HashMap::new(),
child_count: 0,
}
}
}
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Trie {
root_node: Box<TrieNode>,
pub struct Trie<'a> {
#[serde(borrow = "'a")]
root_node: Box<TrieNode<'a>>,
}
impl Trie {
impl<'a> Trie<'a> {
// Create a TrieStruct
pub fn new() -> Trie {
pub fn new<'b>() -> Trie<'b> {
Trie {
root_node: Box::new(TrieNode::new_root()),
}
}
// Insert a string
pub fn insert(&mut self, char_list: FrenchIPAWord) {
pub fn insert(&mut self, char_list: &'a str) {
let mut current_node: &mut TrieNode = self.root_node.as_mut();
let mut last_match = 0;
let iterator = char_list.graphemes(true);
let mut create = false;
current_node.child_count += 1;
// Find the longest existing prefix match
for letter_counter in 0..char_list.len() {
if current_node
.child_nodes
.contains_key(&char_list[letter_counter])
{
current_node = current_node
.child_nodes
.get_mut(&char_list[letter_counter])
.unwrap();
// we mark the node as containing our children.
current_node.child_count += 1;
} else {
last_match = letter_counter;
break;
}
last_match = letter_counter + 1;
}
for str in iterator {
if create == false {
if current_node.child_nodes.contains_key(str) {
current_node = current_node.child_nodes.get_mut(str).unwrap();
// we mark the node as containing our children.
current_node.child_count += 1;
} else {
create = true;
// if we found an already existing node
if last_match == char_list.len() {
current_node.is_final = true;
} else {
for new_counter in last_match..char_list.len() {
let key = char_list[new_counter];
current_node
.child_nodes
.insert(key, TrieNode::new(char_list[new_counter], false));
current_node = current_node.child_nodes.get_mut(&key).unwrap();
current_node.child_count += 1;
current_node.child_nodes.insert(str, TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
current_node.child_count = 1;
}
} else {
current_node.child_nodes.insert(str, TrieNode::new(false));
current_node = current_node.child_nodes.get_mut(str).unwrap();
// we will only have one final node
current_node.child_count = 1;
}
current_node.is_final = true;
}
current_node.is_final = true;
}
// Find a string
pub fn random_starting_with(&self, prefix: FrenchIPAWord) -> Option<String> {
pub fn random_starting_with(&self, prefix: &str) -> Option<String> {
let mut current_node: &TrieNode = self.root_node.as_ref();
let mut str = String::new();
let mut i = prefix.len();
// String for the return value
let mut builder = String::new();
// Iterator over each grapheme
let graphemes = prefix.graphemes(true).enumerate();
// Descend as far as possible into the tree
for counter in prefix {
if let Some(node) = current_node.child_nodes.get(&counter) {
for (_, str) in graphemes {
// If we can descend further into the tree
if let Some(node) = current_node.child_nodes.get(&str) {
current_node = node;
if let Some(value) = current_node.value {
str += value.to_char();
i -= 1;
}
builder += str;
println!("going into node {}", builder);
} else {
// couldn't descend fully into the tree
// this basically means nothing exists in the tree
// with this prefix
println!("no matches for prefix!");
return None;
}
}
println!("Found common root node {}", str);
println!("Found common root node {}", builder);
builder = String::new();
let mut rng = thread_rng();
// Ignore the 0-len matches
if i == 0 && current_node.child_nodes.len() == 0 {
println!("removing 0-len match");
return None;
}
str = String::new();
while current_node.child_nodes.len() != 0 {
// We need to choose a random child based on weights
let weighted = WeightedIndex::new(
current_node
.child_nodes
.iter()
.map(|(_, node)| node.child_count),
)
.expect("distribution creation should be valid");
// now that we have the node we descend by respecting the probabilities
while current_node.child_nodes.len() != 0 && current_node.child_count > 0 {
println!("Descending into node {}", str);
let max = current_node.child_count;
let random_number = thread_rng().gen_range(0..max);
let mut increment = 0;
let (key, node) = current_node
.child_nodes
.iter()
.nth(weighted.sample(&mut rng))
.expect("choosed value did not exist");
println!("waling into node {}", key);
let mut did_change = false;
// find the child node selected by the random number
for (_, node) in &current_node.child_nodes {
if node.child_count + increment >= random_number {
println!("changing node");
current_node = node;
did_change = true;
break;
} else {
println!(
"didn't change node: {}<{}",
node.child_count + increment,
random_number
)
}
increment += node.child_count;
}
if did_change {
if let Some(value) = current_node.value {
println!("added {}", value.to_char());
str += value.to_char();
}
} else {
println!(
"WARNING: DIDNT CHANGE NODE child_count={}",
current_node.child_count
)
}
// if this node is a final node, we have a probability of using it
current_node = node;
builder += key;
// If this node is final and has children
if current_node.is_final && current_node.child_count > 0 {
let random_number = thread_rng().gen_range(0..current_node.child_count);
if random_number == 0 {
// choose between stopping at the current node or continuing with its children
let weighted = WeightedIndex::new(&[1, current_node.child_count])
.expect("distribution seems impossible");
if weighted.sample(&mut rng) == 0 {
// we chose this node!
// stop adding other characters
break;
}
}
}
if str == "" {
// If nothing was added after the prefix, there is no completion
if builder == "" {
return None;
}
// selected word
Some(str)
Some(builder)
}
}

View file

@ -10,19 +10,23 @@ import { request } from "undici";
// `autofeur_db` service
export const DB = process.env.DB || "http://localhost:3000";
// nats broker for connecting to nova
export const NATS = process.env.NATS || "localhost:4222";
export const NATS = process.env.NATS || "192.168.0.17:4222";
// rest endpoint for connecting to nova
export const REST = process.env.REST || "http://localhost:8090/api";
export const REST = process.env.REST || "http://192.168.0.17:8090/api";
/**
* Completes a grapheme using the `autofeur_db` service.
* @param grapheme Grapheme to complete
* @returns Completed grapheme
*/
export const completeWord = (grapheme: string) =>
request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`).then((x) =>
x.body.text()
);
export const completeWord = (grapheme: string) => {
  const url = `${DB}?grapheme=${encodeURIComponent(grapheme)}`;
  // Resolve to the response text only for a 200 status; any other status
  // resolves to undefined.
  return request(url).then((response) =>
    response.statusCode === 200 ? response.body.text() : undefined
  );
};
/**
* Cleans a sentence for usage with this program, strips unwanted chars

View file

@ -23,6 +23,8 @@ services:
deep_phonemizer:
build: deep_phonemizer
restart: always
ports:
- 8000:8000
volumes:
- ./deep_phonemizer/assets/model.pt:/app/assets/model.pt
nats: