new db and readability
This commit is contained in:
parent
acdcd9bbcb
commit
9f798c57f8
|
@ -1,6 +1,5 @@
|
|||
use std::fs;
|
||||
|
||||
use autofeur::french_ipa::parse_word;
|
||||
use autofeur::save::Save;
|
||||
use kdam::tqdm;
|
||||
|
||||
|
@ -42,18 +41,10 @@ async fn main() {
|
|||
phonems.append(&mut pron);
|
||||
}
|
||||
|
||||
let mut invalid = 0;
|
||||
for phoneme in tqdm!(phonems.iter()) {
|
||||
match parse_word(&phoneme) {
|
||||
Some(a) => save.trie.insert(a),
|
||||
None => {
|
||||
invalid += 1;
|
||||
}
|
||||
}
|
||||
save.trie.insert(&phoneme);
|
||||
}
|
||||
|
||||
println!("Invalid items count: {}", invalid);
|
||||
|
||||
fs::write("assets/db.bin", bincode::serialize(&save).unwrap()).unwrap();
|
||||
|
||||
println!("Generated to assets/db.bin");
|
||||
|
|
|
@ -17,28 +17,47 @@ fn parse_query(query: &str) -> HashMap<String, String> {
|
|||
.collect()
|
||||
}
|
||||
|
||||
async fn handler(request: Request<Body>) -> Result<Response<Body>, anyhow::Error> {
|
||||
fn anyhow_response(err: anyhow::Error) -> Response<Body> {
|
||||
Response::builder()
|
||||
.status(400)
|
||||
.body(Body::from(err.root_cause().to_string()))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn handler(request: Request<Body>) -> Result<Response<Body>, hyper::Error> {
|
||||
let save: &Arc<Save> = request.extensions().get().unwrap();
|
||||
let query = request
|
||||
let query = match request
|
||||
.uri()
|
||||
.query()
|
||||
.ok_or_else(|| anyhow!("query does not exists"))?;
|
||||
let data = parse_query(query)
|
||||
.ok_or_else(|| anyhow_response(anyhow!("query does not exists")))
|
||||
{
|
||||
Ok(ok) => ok,
|
||||
Err(err) => return Ok(err),
|
||||
};
|
||||
let data = match parse_query(query)
|
||||
.get("grapheme")
|
||||
.ok_or_else(|| anyhow!("grapheme argument is not specified"))?
|
||||
.clone();
|
||||
.ok_or_else(|| anyhow_response(anyhow!("grapheme argument is not specified")))
|
||||
{
|
||||
Ok(ok) => ok.clone(),
|
||||
Err(err) => return Ok(err),
|
||||
};
|
||||
|
||||
let infered = save
|
||||
let infered = match save
|
||||
.inference(&data)
|
||||
.await
|
||||
.or_else(|_| Err(anyhow!("cannot find data")))?;
|
||||
.or_else(|e| Err(anyhow_response(e.context("inference error"))))
|
||||
{
|
||||
Ok(ok) => ok,
|
||||
Err(e) => return Ok(e),
|
||||
};
|
||||
|
||||
Ok(Response::builder().body(Body::from(infered)).unwrap())
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let checkpoint: Save = bincode::deserialize(&fs::read("assets/db.bin").unwrap()).unwrap();
|
||||
let data = Box::leak(Box::new(fs::read("assets/db.bin").unwrap()));
|
||||
let checkpoint: Save = bincode::deserialize(data).unwrap();
|
||||
let service = ServiceBuilder::new()
|
||||
.layer(AddExtensionLayer::new(Arc::new(checkpoint)))
|
||||
// Wrap a `Service` in our middleware stack
|
||||
|
|
|
@ -1,129 +0,0 @@
|
|||
use std::hash::Hash;
|
||||
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
macro_rules! ipa_element_to_number {
|
||||
(@step $_idx:expr, $ident:ident,) => {
|
||||
None
|
||||
};
|
||||
|
||||
(@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => {
|
||||
if $ident == $head {
|
||||
Some(Self($idx))
|
||||
}
|
||||
else {
|
||||
ipa_element_to_number!(@step $idx + 1usize, $ident, $($tail,)*)
|
||||
}
|
||||
};
|
||||
}
|
||||
macro_rules! ipa_number_to_ipa {
|
||||
(@step $_idx:expr, $ident:ident,) => {
|
||||
"unreachable!()"
|
||||
};
|
||||
|
||||
(@step $idx:expr, $ident:ident, $head:literal, $($tail:literal,)*) => {
|
||||
if $ident == $idx {
|
||||
$head
|
||||
}
|
||||
else {
|
||||
ipa_number_to_ipa!(@step $idx + 1usize, $ident, $($tail,)*)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! replace_expr {
|
||||
($_t:tt $sub:expr) => {
|
||||
$sub
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! count_tts {
|
||||
($($tts:tt)*) => {0usize $(+ replace_expr!($tts 1usize))*};
|
||||
}
|
||||
|
||||
macro_rules! ipa_map {
|
||||
($name:ident, $($l:literal),*) => {
|
||||
use serde::{Deserialize, Serialize};
|
||||
#[derive(Eq, Hash, PartialEq, Debug, Copy, Clone, Serialize, Deserialize)]
|
||||
pub struct $name(pub usize);
|
||||
|
||||
impl $name {
|
||||
pub const SIZE: usize = count_tts!($($l,)*);
|
||||
|
||||
pub fn from_char(ch: &str) -> Option<$name> {
|
||||
ipa_element_to_number!(@step 0usize, ch, $($l,)*)
|
||||
}
|
||||
pub fn to_char(self) -> &'static str {
|
||||
let num = self.0;
|
||||
ipa_number_to_ipa!(@step 0usize, num, $($l,)*)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
ipa_map!(
|
||||
FrenchIPAChar,
|
||||
"a",
|
||||
"ɑ",
|
||||
"ɑ̃",
|
||||
"e",
|
||||
"ɛ",
|
||||
"ɛ̃",
|
||||
"ə",
|
||||
"i",
|
||||
"o",
|
||||
"ɔ",
|
||||
"ɔ̃",
|
||||
"œ",
|
||||
"œ̃",
|
||||
"ø",
|
||||
"u",
|
||||
"y",
|
||||
"j",
|
||||
"ɥ",
|
||||
"w",
|
||||
"b",
|
||||
"d",
|
||||
"f",
|
||||
"g",
|
||||
"k",
|
||||
"l",
|
||||
"m",
|
||||
"n",
|
||||
"ɲ",
|
||||
"ŋ",
|
||||
"p",
|
||||
"ʁ",
|
||||
"s",
|
||||
"ʃ",
|
||||
"t",
|
||||
"v",
|
||||
"z",
|
||||
"ʒ",
|
||||
"g",
|
||||
"ɡ",
|
||||
"ɪ",
|
||||
"ʊ",
|
||||
"x",
|
||||
"r"
|
||||
);
|
||||
|
||||
pub type FrenchIPAWord = Vec<FrenchIPAChar>;
|
||||
|
||||
pub fn parse_word(str: &str) -> Option<FrenchIPAWord> {
|
||||
let mut word = FrenchIPAWord::default();
|
||||
let graphemes: Vec<&str> = str.graphemes(true).collect();
|
||||
for (_, grapheme) in graphemes.iter().enumerate() {
|
||||
let a = FrenchIPAChar::from_char(grapheme);
|
||||
|
||||
word.push(match a {
|
||||
None => {
|
||||
println!("invalid char: {}", grapheme);
|
||||
return None;
|
||||
}
|
||||
Some(a) => a,
|
||||
})
|
||||
}
|
||||
|
||||
Some(word)
|
||||
}
|
|
@ -5,11 +5,11 @@ use itertools::Itertools;
|
|||
use levenshtein::levenshtein;
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
use crate::{french_ipa::parse_word, save::Save};
|
||||
use crate::save::Save;
|
||||
|
||||
async fn call_inference_service(word: &str) -> anyhow::Result<String> {
|
||||
let server: Result<String, anyhow::Error> =
|
||||
env::var("PHONEMIZER").or_else(|_| Ok("".to_string()));
|
||||
env::var("PHONEMIZER").or_else(|_| Ok("http://localhost:8000/".to_string()));
|
||||
Ok(
|
||||
reqwest::get(format!("{}?grapheme={}", server.unwrap(), word))
|
||||
.await?
|
||||
|
@ -18,22 +18,23 @@ async fn call_inference_service(word: &str) -> anyhow::Result<String> {
|
|||
)
|
||||
}
|
||||
|
||||
impl Save {
|
||||
impl Save<'_> {
|
||||
pub async fn inference(&self, prefix: &str) -> anyhow::Result<String> {
|
||||
let phonemes = call_inference_service(prefix).await?;
|
||||
let ipa_phonemes =
|
||||
parse_word(&phonemes).ok_or_else(|| anyhow!("failed to parse the word"))?;
|
||||
|
||||
let completion = self
|
||||
.trie
|
||||
.random_starting_with(ipa_phonemes)
|
||||
.random_starting_with(&phonemes)
|
||||
.ok_or_else(|| anyhow!("no matches"))?;
|
||||
|
||||
let infered = phonemes.add(&completion);
|
||||
let word = self
|
||||
.reverse_index
|
||||
.get(&infered)
|
||||
.ok_or_else(|| anyhow!("matched values is not in dictionary"))?;
|
||||
let infered = phonemes.clone().add(&completion);
|
||||
let word = self.reverse_index.get(&infered).ok_or_else(|| {
|
||||
anyhow!(
|
||||
"matched value is not in dictionary {} {}",
|
||||
infered,
|
||||
phonemes
|
||||
)
|
||||
})?;
|
||||
|
||||
println!("Matching {} by adding {}", word, completion);
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
pub mod trie;
|
||||
pub mod french_ipa;
|
||||
pub mod save;
|
||||
pub mod inference;
|
||||
|
|
|
@ -5,7 +5,8 @@ use serde::{Deserialize, Serialize};
|
|||
use crate::trie::Trie;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Default)]
|
||||
pub struct Save {
|
||||
pub trie: Trie,
|
||||
pub reverse_index: HashMap<String, String>
|
||||
}
|
||||
pub struct Save<'a> {
|
||||
#[serde(borrow = "'a")]
|
||||
pub trie: Trie<'a>,
|
||||
pub reverse_index: HashMap<String, String>,
|
||||
}
|
||||
|
|
|
@ -1,169 +1,150 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand::{distributions::WeightedIndex, prelude::Distribution, thread_rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::french_ipa::{FrenchIPAChar, FrenchIPAWord};
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct TrieNode {
|
||||
value: Option<FrenchIPAChar>,
|
||||
pub struct TrieNode<'a> {
|
||||
is_final: bool,
|
||||
child_nodes: HashMap<FrenchIPAChar, TrieNode>,
|
||||
#[serde(borrow = "'a")]
|
||||
child_nodes: HashMap<&'a str, TrieNode<'a>>,
|
||||
child_count: u64,
|
||||
}
|
||||
|
||||
impl TrieNode {
|
||||
impl<'a> TrieNode<'a> {
|
||||
// Create new node
|
||||
pub fn new(c: FrenchIPAChar, is_final: bool) -> TrieNode {
|
||||
pub fn new<'b>(is_final: bool) -> TrieNode<'b> {
|
||||
TrieNode {
|
||||
value: Option::Some(c),
|
||||
is_final,
|
||||
child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
|
||||
child_nodes: HashMap::new(),
|
||||
child_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_root() -> TrieNode {
|
||||
pub fn new_root<'b>() -> TrieNode<'b> {
|
||||
TrieNode {
|
||||
value: Option::None,
|
||||
is_final: false,
|
||||
child_nodes: HashMap::with_capacity(FrenchIPAChar::SIZE),
|
||||
child_nodes: HashMap::new(),
|
||||
child_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct Trie {
|
||||
root_node: Box<TrieNode>,
|
||||
pub struct Trie<'a> {
|
||||
#[serde(borrow = "'a")]
|
||||
root_node: Box<TrieNode<'a>>,
|
||||
}
|
||||
|
||||
impl Trie {
|
||||
impl<'a> Trie<'a> {
|
||||
// Create a TrieStruct
|
||||
pub fn new() -> Trie {
|
||||
pub fn new<'b>() -> Trie<'b> {
|
||||
Trie {
|
||||
root_node: Box::new(TrieNode::new_root()),
|
||||
}
|
||||
}
|
||||
|
||||
// Insert a string
|
||||
pub fn insert(&mut self, char_list: FrenchIPAWord) {
|
||||
pub fn insert(&mut self, char_list: &'a str) {
|
||||
let mut current_node: &mut TrieNode = self.root_node.as_mut();
|
||||
let mut last_match = 0;
|
||||
let iterator = char_list.graphemes(true);
|
||||
let mut create = false;
|
||||
|
||||
current_node.child_count += 1;
|
||||
// Find the minimum existing math
|
||||
for letter_counter in 0..char_list.len() {
|
||||
if current_node
|
||||
.child_nodes
|
||||
.contains_key(&char_list[letter_counter])
|
||||
{
|
||||
current_node = current_node
|
||||
.child_nodes
|
||||
.get_mut(&char_list[letter_counter])
|
||||
.unwrap();
|
||||
// we mark the node as containing our children.
|
||||
current_node.child_count += 1;
|
||||
} else {
|
||||
last_match = letter_counter;
|
||||
break;
|
||||
}
|
||||
last_match = letter_counter + 1;
|
||||
}
|
||||
for str in iterator {
|
||||
if create == false {
|
||||
if current_node.child_nodes.contains_key(str) {
|
||||
current_node = current_node.child_nodes.get_mut(str).unwrap();
|
||||
// we mark the node as containing our children.
|
||||
current_node.child_count += 1;
|
||||
} else {
|
||||
create = true;
|
||||
|
||||
// if we found an already exsting node
|
||||
if last_match == char_list.len() {
|
||||
current_node.is_final = true;
|
||||
} else {
|
||||
for new_counter in last_match..char_list.len() {
|
||||
let key = char_list[new_counter];
|
||||
current_node
|
||||
.child_nodes
|
||||
.insert(key, TrieNode::new(char_list[new_counter], false));
|
||||
current_node = current_node.child_nodes.get_mut(&key).unwrap();
|
||||
current_node.child_count += 1;
|
||||
current_node.child_nodes.insert(str, TrieNode::new(false));
|
||||
current_node = current_node.child_nodes.get_mut(str).unwrap();
|
||||
current_node.child_count = 1;
|
||||
}
|
||||
} else {
|
||||
current_node.child_nodes.insert(str, TrieNode::new(false));
|
||||
current_node = current_node.child_nodes.get_mut(str).unwrap();
|
||||
// we will only have one final node
|
||||
current_node.child_count = 1;
|
||||
}
|
||||
current_node.is_final = true;
|
||||
}
|
||||
current_node.is_final = true;
|
||||
}
|
||||
|
||||
// Find a string
|
||||
pub fn random_starting_with(&self, prefix: FrenchIPAWord) -> Option<String> {
|
||||
pub fn random_starting_with(&self, prefix: &str) -> Option<String> {
|
||||
let mut current_node: &TrieNode = self.root_node.as_ref();
|
||||
let mut str = String::new();
|
||||
let mut i = prefix.len();
|
||||
// String for the return value
|
||||
let mut builder = String::new();
|
||||
|
||||
// Iterator over each grapheme
|
||||
let graphemes = prefix.graphemes(true).enumerate();
|
||||
|
||||
// Descend as far as possible into the tree
|
||||
for counter in prefix {
|
||||
if let Some(node) = current_node.child_nodes.get(&counter) {
|
||||
for (_, str) in graphemes {
|
||||
// If we can descend further into the tree
|
||||
if let Some(node) = current_node.child_nodes.get(&str) {
|
||||
current_node = node;
|
||||
if let Some(value) = current_node.value {
|
||||
str += value.to_char();
|
||||
i -= 1;
|
||||
}
|
||||
builder += str;
|
||||
println!("going into node {}", builder);
|
||||
} else {
|
||||
// couldn't descend fully into the tree
|
||||
// this basically means nothing exist in the tree
|
||||
// with this prefix
|
||||
println!("no matches for prefix!");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
println!("Found common root node {}", str);
|
||||
println!("Found common root node {}", builder);
|
||||
builder = String::new();
|
||||
let mut rng = thread_rng();
|
||||
|
||||
// Ignore the 0-len matches
|
||||
if i == 0 && current_node.child_nodes.len() == 0 {
|
||||
println!("removing 0-len match");
|
||||
return None;
|
||||
}
|
||||
str = String::new();
|
||||
while current_node.child_nodes.len() != 0 {
|
||||
// We need to choose a random child based on weights
|
||||
let weighted = WeightedIndex::new(
|
||||
current_node
|
||||
.child_nodes
|
||||
.iter()
|
||||
.map(|(_, node)| node.child_count),
|
||||
)
|
||||
.expect("distribution creation should be valid");
|
||||
|
||||
// now that we have the node we descend by respecting the probabilities
|
||||
while current_node.child_nodes.len() != 0 && current_node.child_count > 0 {
|
||||
println!("Descending into node {}", str);
|
||||
let max = current_node.child_count;
|
||||
let random_number = thread_rng().gen_range(0..max);
|
||||
let mut increment = 0;
|
||||
let (key, node) = current_node
|
||||
.child_nodes
|
||||
.iter()
|
||||
.nth(weighted.sample(&mut rng))
|
||||
.expect("choosed value did not exist");
|
||||
println!("waling into node {}", key);
|
||||
|
||||
let mut did_change = false;
|
||||
// find node corresponding to the node
|
||||
for (_, node) in ¤t_node.child_nodes {
|
||||
if node.child_count + increment >= random_number {
|
||||
println!("changing node");
|
||||
current_node = node;
|
||||
did_change = true;
|
||||
break;
|
||||
} else {
|
||||
println!(
|
||||
"didn't change node: {}<{}",
|
||||
node.child_count + increment,
|
||||
random_number
|
||||
)
|
||||
}
|
||||
increment += node.child_count;
|
||||
}
|
||||
if did_change {
|
||||
if let Some(value) = current_node.value {
|
||||
println!("added {}", value.to_char());
|
||||
str += value.to_char();
|
||||
}
|
||||
} else {
|
||||
println!(
|
||||
"WARNING: DIDNT CHANGE NODE child_count={}",
|
||||
current_node.child_count
|
||||
)
|
||||
}
|
||||
// if this node is a final node, we have a probability of using it
|
||||
current_node = node;
|
||||
builder += key;
|
||||
|
||||
// If this node is final and has childrens
|
||||
if current_node.is_final && current_node.child_count > 0 {
|
||||
let random_number = thread_rng().gen_range(0..current_node.child_count);
|
||||
if random_number == 0 {
|
||||
// choose from current node or continue with childrens
|
||||
let weighted = WeightedIndex::new(&[1, current_node.child_count])
|
||||
.expect("distribution seems impossible");
|
||||
|
||||
if weighted.sample(&mut rng) == 0 {
|
||||
// we choosed this node!
|
||||
// stop adding other characters
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if str == "" {
|
||||
// If we only added
|
||||
if builder == "" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// selected word
|
||||
Some(str)
|
||||
Some(builder)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,19 +10,23 @@ import { request } from "undici";
|
|||
// `autofeur_db` service
|
||||
export const DB = process.env.DB || "http://localhost:3000";
|
||||
// nats broker for connecting to nova
|
||||
export const NATS = process.env.NATS || "localhost:4222";
|
||||
export const NATS = process.env.NATS || "192.168.0.17:4222";
|
||||
// rest endpoint for connecting to nova
|
||||
export const REST = process.env.REST || "http://localhost:8090/api";
|
||||
export const REST = process.env.REST || "http://192.168.0.17:8090/api";
|
||||
|
||||
/**
|
||||
* Completes a grapheme using the `autofeur_db` service.
|
||||
* @param grapheme Grapheme to complete
|
||||
* @returns Completed grapheme
|
||||
*/
|
||||
export const completeWord = (grapheme: string) =>
|
||||
request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`).then((x) =>
|
||||
x.body.text()
|
||||
);
|
||||
export const completeWord = (grapheme: string) => {
|
||||
let resp = request(`${DB}?grapheme=${encodeURIComponent(grapheme)}`);
|
||||
return resp.then((x) => {
|
||||
if (x.statusCode === 200) {
|
||||
return x.body.text();
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Cleans a sentence for usage with this program, strips unwanted chars
|
||||
|
|
|
@ -23,6 +23,8 @@ services:
|
|||
deep_phonemizer:
|
||||
build: deep_phonemizer
|
||||
restart: always
|
||||
ports:
|
||||
- 8000:8000
|
||||
volumes:
|
||||
- ./deep_phonemizer/assets/model.pt:/app/assets/model.pt
|
||||
nats:
|
||||
|
|
Loading…
Reference in a new issue