添加 inventory_manager.py

This commit is contained in:
spdis 2025-07-27 11:21:57 +08:00
parent 0f090d5089
commit ca4f910012

632
inventory_manager.py Normal file
View File

@@ -0,0 +1,632 @@
import json
from decimal import Decimal
from datetime import datetime
import ijson
import copy
class InventoryManager:
    """Inventory ledger persisted in a single JSON file.

    The file stores ``products`` (current stock weight and weighted-average
    price per product, both as 8-decimal strings) and ``transactions``
    (newest first).  Snapshot entries (``is_snapshot: True``) are interleaved
    with the records; each one captures the product state right after the
    record whose ``time``/``daily_sequence`` it is stamped with.

    Core behaviour:
    - Inbound records (``入库``) update the weighted-average price; outbound
      records (``出库``) are priced at the running average and reduce stock.
    - A snapshot is written every ``snapshot_interval`` records.
    - Out-of-order inserts and deletions replay the ledger from the nearest
      usable snapshot and rebuild every later snapshot.
    - The transaction list is kept in strict newest-first order.
    - Long operations report progress via ``callback(current, total, message)``.

    NOTE: every operation loads the whole file with ``json.load``; nothing is
    streamed, so memory use grows with the size of the ledger.
    """

    INBOUND = '入库'
    OUTBOUND = '出库'
    _ZERO = '0.00000000'
    _DATE_FORMAT = '%Y-%m-%d'

    class InsufficientStockError(Exception):
        """Raised when replaying the ledger would drive a stock level negative."""

    def __init__(self, data_path):
        """Bind the manager to ``data_path``, creating the file if it is missing."""
        import os

        self.data_path = data_path
        self.snapshot_interval = 50
        # Number of records appended since the last snapshot (per process).
        self.need_calculate = 0
        if not os.path.exists(self.data_path):
            self.initialize_data_file()

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------
    @classmethod
    def _zero_state(cls):
        """Return a fresh product state with no stock and no price."""
        return {'total_weight': cls._ZERO, 'avg_price': cls._ZERO}

    @staticmethod
    def _record_key(txn):
        """Return the chronological ordering key of a transaction record."""
        return (txn.get('time', ''), txn.get('daily_sequence', 0))

    @staticmethod
    def _snapshot_key(snapshot):
        """Return the ledger position a snapshot covers, or None if unusable.

        Only snapshots stamped with the sequence of a real record
        (``daily_sequence >= 1``) can be placed on the transaction timeline.
        Snapshots carrying sequence 0 (the initial placeholder, or ones
        stamped with the wall-clock date) are treated as "replay everything".
        """
        sequence = snapshot.get('daily_sequence', 0)
        if not isinstance(sequence, int) or sequence < 1:
            return None
        return (snapshot.get('timestamp', ''), sequence)

    @classmethod
    def _entry_key(cls, entry):
        """Ordering key for any list entry; a snapshot sorts just after the record it covers."""
        if entry.get('is_snapshot'):
            return (entry.get('timestamp', ''), entry.get('daily_sequence', 0), 1)
        return cls._record_key(entry) + (0,)

    def _load(self):
        """Read and return the whole data file."""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _save(self, data):
        """Write ``data`` atomically so a failed dump never truncates the ledger."""
        import os

        tmp_path = f"{self.data_path}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
            os.replace(tmp_path, self.data_path)
        except Exception:
            # Leave the previous file untouched and do not litter temp files.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    # ------------------------------------------------------------------
    # Adding transactions
    # ------------------------------------------------------------------
    def process_transaction(self, new_txn, progress_callback=None):
        """Validate and store a new transaction (main entry point).

        Args:
            new_txn: Transaction dict with ``product``, ``type``, ``weight``,
                ``time`` (``YYYY-MM-DD``) and, for inbound records, ``price``.
                The dict is updated in place with ``daily_sequence`` and, for
                outbound records, the computed ``price``.
            progress_callback: Optional ``callback(current, total, message)``
                invoked while older records are being recalculated.

        Returns:
            ``{"success": bool, "message": str, "recalculated_count": int}``.
            Failures are reported through the dict; nothing is raised.
        """
        try:
            if not self._validate_transaction(new_txn):
                return {"success": False, "message": "交易记录格式错误", "recalculated_count": 0}
            # Load once and share the data with the helpers below.
            data = self._load()
            transactions = data['transactions']
            new_txn['daily_sequence'] = self._get_next_sequence(new_txn['time'], transactions)
            insert_position = self._find_insert_position(new_txn, transactions)
            if insert_position == 0:
                return self._insert_latest_transaction(new_txn, progress_callback, data)
            return self._insert_historical_transaction(
                new_txn, insert_position, progress_callback, data
            )
        except self.InsufficientStockError:
            return {"success": False, "message": "库存不足", "recalculated_count": 0}
        except Exception as e:
            return {"success": False, "message": f"处理失败: {str(e)}", "recalculated_count": 0}

    def _validate_transaction(self, txn):
        """Return True when ``txn`` has every required field in a usable format.

        Rejected: non-dict input, unknown ``type``, missing fields, weights or
        prices that are not finite numbers, negative weights, and dates that
        are not zero-padded ``YYYY-MM-DD`` (ordering relies on comparing the
        date strings directly).
        """
        if not isinstance(txn, dict):
            return False
        txn_type = txn.get('type')
        if txn_type not in (self.INBOUND, self.OUTBOUND):
            return False
        required_fields = ['product', 'weight', 'time']
        if txn_type == self.INBOUND:
            required_fields.append('price')
        if any(field not in txn for field in required_fields):
            return False
        try:
            weight = Decimal(txn['weight'])
            if not weight.is_finite() or weight < 0:
                return False
            if txn_type == self.INBOUND and not Decimal(txn['price']).is_finite():
                return False
            parsed = datetime.strptime(txn['time'], self._DATE_FORMAT)
            if parsed.strftime(self._DATE_FORMAT) != txn['time']:
                return False
        except (ArithmeticError, ValueError, TypeError):
            # decimal.InvalidOperation is an ArithmeticError subclass.
            return False
        return True

    def _get_next_sequence(self, timestamp, transactions=None):
        """Return the next ``daily_sequence`` for records dated ``timestamp``.

        ``transactions`` may be passed to avoid re-reading the file; when it
        is omitted and the file cannot be read, 1 is returned.
        """
        if transactions is None:
            try:
                transactions = self._load()['transactions']
            except (OSError, ValueError, KeyError, TypeError):
                return 1
        same_day = [
            txn.get('daily_sequence', 0)
            for txn in transactions
            if txn.get('time') == timestamp and not txn.get('is_snapshot')
        ]
        return max(same_day, default=0) + 1

    def _find_insert_position(self, new_txn, transactions=None):
        """Return the list index at which ``new_txn`` keeps newest-first order.

        0 means the record is newer than every stored record.  Any other
        value is an index into the full list (snapshots included), so it can
        be passed straight to ``list.insert``.
        """
        if transactions is None:
            try:
                transactions = self._load()['transactions']
            except (OSError, ValueError, KeyError, TypeError):
                return 0
        new_key = self._record_key(new_txn)
        seen_record = False
        for index, entry in enumerate(transactions):
            if entry.get('is_snapshot'):
                continue
            if new_key > self._record_key(entry):
                # Snapshots ahead of the first record must not turn a
                # latest insert into a historical one.
                return index if seen_record else 0
            seen_record = True
        return len(transactions) if seen_record else 0

    def _insert_latest_transaction(self, new_txn, progress_callback, data=None):
        """Append the newest record; no recalculation is needed.

        Outbound records are priced at the product's current average, which
        is exactly what a later replay would assign to them.
        """
        try:
            if data is None:
                data = self._load()
            if new_txn['type'] == self.OUTBOUND:
                new_txn['price'] = self._calculate_outbound_price_from_products(
                    data['products'], new_txn['product'], new_txn['time']
                )
            if not self._update_product_inventory(data['products'], new_txn):
                return {"success": False, "message": "库存不足", "recalculated_count": 0}
            data['transactions'].insert(0, new_txn)
            pending = self.need_calculate + 1
            if pending >= self.snapshot_interval:
                data['transactions'].insert(
                    0, self._create_snapshot(data['products'], as_of=new_txn)
                )
                pending = 0
            self._save(data)
            # Only advance the counter once the record is safely on disk.
            self.need_calculate = pending
            return {"success": True, "message": "记录添加成功", "recalculated_count": 0}
        except Exception as e:
            raise Exception(f"插入最新记录失败: {str(e)}") from e

    def _insert_historical_transaction(self, new_txn, insert_position, progress_callback,
                                       data=None):
        """Insert an out-of-order record and replay everything after the base snapshot.

        Raises:
            InsufficientStockError: the replay would make a stock level negative.
            Exception: any other failure, wrapped with a descriptive message.
        """
        try:
            if data is None:
                data = self._load()
            base_snapshot = self._find_base_snapshot(new_txn['time'], data['transactions'])
            # Outbound prices are assigned during the replay, from the rebuilt state.
            data['transactions'].insert(insert_position, new_txn)
            recalculated_count = self._recalculate_from_snapshot(
                data, base_snapshot, insert_position, progress_callback
            )
            self._save(data)
            return {
                "success": True,
                "message": f"历史记录插入成功,重计算了{recalculated_count}条记录",
                "recalculated_count": recalculated_count
            }
        except self.InsufficientStockError:
            raise
        except Exception as e:
            raise Exception(f"插入历史记录失败: {str(e)}") from e

    # ------------------------------------------------------------------
    # Snapshots
    # ------------------------------------------------------------------
    def _find_base_snapshot(self, target_date, transactions):
        """Return the newest usable snapshot dated strictly before ``target_date``.

        The comparison is strict so that every record on ``target_date`` is
        replayed.  When no usable snapshot exists an empty placeholder is
        returned, which makes the caller replay the whole ledger.
        """
        best, best_key = None, None
        for entry in transactions:
            if not entry.get('is_snapshot'):
                continue
            key = self._snapshot_key(entry)
            if key is None or key[0] >= target_date:
                continue
            if best_key is None or key > best_key:
                best, best_key = entry, key
        if best is not None:
            return best
        return {
            'is_snapshot': True,
            'timestamp': '1900-01-01',
            'products': {}
        }

    def _remove_snapshots_after_position(self, transactions, position):
        """Remove, in place, every snapshot stored at an index greater than ``position``.

        Kept for existing callers; the replay now rebuilds snapshots itself.
        """
        transactions[:] = [
            entry for index, entry in enumerate(transactions)
            if index <= position or not entry.get('is_snapshot')
        ]

    def _create_snapshot(self, products, as_of=None):
        """Return a snapshot of ``products``.

        Args:
            products: Product state to copy into the snapshot.
            as_of: The last record included in ``products``.  Its ``time`` and
                ``daily_sequence`` stamp the snapshot so it can later serve as
                a replay base.  Without it the snapshot is stamped with
                today's date and sequence 0 and is never used as a base.
        """
        if as_of is None:
            timestamp = datetime.now().strftime(self._DATE_FORMAT)
            sequence = 0
        else:
            timestamp = as_of['time']
            sequence = as_of.get('daily_sequence', 0)
        return {
            "is_snapshot": True,
            "timestamp": timestamp,
            "daily_sequence": sequence,
            "products": copy.deepcopy(products)
        }

    def _find_insert_position_for_snapshot(self, transactions, snapshot):
        """Return the first index whose entry is dated before the snapshot.

        Kept for existing callers; the replay now orders entries by sorting.
        """
        snapshot_time = snapshot['timestamp']
        for index, entry in enumerate(transactions):
            if entry.get('time', entry.get('timestamp', '')) < snapshot_time:
                return index
        return len(transactions)

    # ------------------------------------------------------------------
    # Deleting
    # ------------------------------------------------------------------
    def delete_transaction(self, transaction_index, progress_callback=None):
        """Delete one transaction record and recalculate the ledger.

        Args:
            transaction_index: Index of the record in the stored
                ``transactions`` list (snapshots count towards the index).
            progress_callback: Optional ``callback(current, total, message)``.

        Returns:
            ``{"success": bool, "message": str, "recalculated_count": int}``.
        """
        try:
            data = self._load()
            transactions = data['transactions']
            if transaction_index < 0 or transaction_index >= len(transactions):
                return {"success": False, "message": "无效的交易记录索引", "recalculated_count": 0}
            target_txn = transactions[transaction_index]
            if target_txn.get('is_snapshot'):
                return {"success": False, "message": "不能删除快照记录", "recalculated_count": 0}
            # Only snapshots dated before the deleted record are still valid.
            base_snapshot = self._find_base_snapshot(target_txn['time'], transactions)
            transactions.pop(transaction_index)
            recalculated_count = self._recalculate_from_snapshot(
                data, base_snapshot, transaction_index - 1, progress_callback
            )
            self._save(data)
            return {
                "success": True,
                "message": f"交易记录删除成功,重计算了{recalculated_count}条记录",
                "recalculated_count": recalculated_count
            }
        except Exception as e:
            return {"success": False, "message": f"删除失败: {str(e)}", "recalculated_count": 0}

    def find_transaction_index(self, product_name, transaction_time, transaction_type, weight):
        """Return the list index of the first record matching the given fields.

        Args:
            product_name: Product name.
            transaction_time: Transaction date string.
            transaction_type: ``入库`` or ``出库``.
            weight: Weight; compared numerically with a 1e-8 tolerance.

        Returns:
            The index into the stored ``transactions`` list, or -1 when no
            record matches or the file cannot be read.
        """
        try:
            wanted_weight = float(weight)
            for index, txn in enumerate(self._load()['transactions']):
                if txn.get('is_snapshot'):
                    continue
                if (txn.get('product') == product_name and
                        txn.get('time') == transaction_time and
                        txn.get('type') == transaction_type and
                        abs(float(txn.get('weight', 0)) - wanted_weight) < 0.00000001):
                    return index
            return -1
        except Exception as e:
            print(f"查找交易记录索引失败: {str(e)}")
            return -1

    def delete_product(self, product_name, progress_callback=None):
        """Delete a product together with all of its transaction records.

        Args:
            product_name: Name of the product to delete.
            progress_callback: Optional ``callback(current, total, message)``.

        Returns:
            ``{"success": bool, "message": str, "deleted_transactions": int}``.
        """
        try:
            data = self._load()
            product_name_str = str(product_name)
            if product_name_str not in data["products"]:
                return {"success": False, "message": f"未找到产品: {product_name_str}", "deleted_transactions": 0}
            del data["products"][product_name_str]
            original_count = len(data["transactions"])
            data["transactions"] = [
                t for t in data["transactions"]
                if str(t.get("product")) != product_name_str or t.get("is_snapshot")
            ]
            deleted_count = original_count - len(data["transactions"])
            # Always rebuild: existing snapshots may still carry the deleted
            # product and would bring it back on a later replay.
            self._recalculate_all_products(data, progress_callback)
            self._save(data)
            return {
                "success": True,
                "message": f"产品删除成功,删除了{deleted_count}条相关交易记录",
                "deleted_transactions": deleted_count
            }
        except Exception as e:
            return {"success": False, "message": f"删除产品失败: {str(e)}", "deleted_transactions": 0}

    # ------------------------------------------------------------------
    # Recalculation
    # ------------------------------------------------------------------
    def _replay(self, data, base_snapshot, progress_callback, describe):
        """Rebuild product state and snapshots from ``base_snapshot`` onwards.

        Records later than the base snapshot are applied oldest-first on top
        of the snapshot's state.  Snapshots later than the base (and any
        unusable ones) are discarded and fresh ones are created every
        ``snapshot_interval`` replayed records.  ``data`` is updated in
        place; its transaction list ends up sorted newest-first.

        Returns:
            Number of records replayed.

        Raises:
            InsufficientStockError: an outbound record exceeds the stock.
        """
        base_key = self._snapshot_key(base_snapshot)
        if base_key is None:
            products = {}
        else:
            products = copy.deepcopy(base_snapshot.get('products', {}))
        # Products known to the file but absent from the snapshot start at zero.
        for product_name in data.get('products', {}):
            products.setdefault(product_name, self._zero_state())

        records, kept_snapshots = [], []
        for entry in data['transactions']:
            if not entry.get('is_snapshot'):
                records.append(entry)
                continue
            key = self._snapshot_key(entry)
            if base_key is not None and key is not None and key <= base_key:
                kept_snapshots.append(entry)

        pending = [
            txn for txn in records
            if base_key is None or self._record_key(txn) > base_key
        ]
        pending.sort(key=self._record_key)

        total = len(pending)
        new_snapshots = []
        since_snapshot = 0
        for done, txn in enumerate(pending, start=1):
            if progress_callback:
                progress_callback(done, total, describe(done, total, txn))
            if txn['type'] == self.INBOUND:
                self._update_product_inventory(products, txn)
            elif txn['type'] == self.OUTBOUND:
                # Price from the state rebuilt so far, never from stored values.
                txn['price'] = self._calculate_outbound_price_from_products(
                    products, txn['product'], txn['time']
                )
                if not self._update_product_inventory(products, txn):
                    raise self.InsufficientStockError("重计算过程中发现库存不足")
            since_snapshot += 1
            if since_snapshot >= self.snapshot_interval:
                new_snapshots.append(self._create_snapshot(products, as_of=txn))
                since_snapshot = 0

        # A stable sort keeps the stored order of entries with equal keys.
        data['transactions'] = sorted(
            records + kept_snapshots + new_snapshots,
            key=self._entry_key,
            reverse=True,
        )
        data['products'] = products
        return total

    def _recalculate_all_products(self, data, progress_callback):
        """Discard every snapshot and rebuild all product state from scratch."""
        for product_name in data["products"]:
            data["products"][product_name] = self._zero_state()
        empty_base = {'is_snapshot': True, 'timestamp': '1900-01-01', 'products': {}}
        self._replay(
            data, empty_base, progress_callback,
            lambda done, total, txn: f"重新计算交易记录: {txn.get('product', '')}",
        )

    def _recalculate_from_snapshot(self, data, base_snapshot, start_position, progress_callback):
        """Replay every record later than ``base_snapshot``.

        ``start_position`` is accepted for compatibility and not used.

        Returns:
            Number of records replayed.
        """
        return self._replay(
            data, base_snapshot, progress_callback,
            lambda done, total, txn: f"重计算记录 {done}/{total}",
        )

    def _calculate_outbound_price_from_products(self, products, product_name, outbound_date):
        """Return the product's current average price, or zero if it is unknown."""
        if product_name in products:
            return products[product_name].get('avg_price', self._ZERO)
        return self._ZERO

    def _update_product_inventory(self, products, txn):
        """Apply ``txn`` to ``products`` in place.

        Returns:
            False when an outbound record exceeds the available stock (the
            state is left unchanged), True otherwise.
        """
        product = products.setdefault(txn['product'], self._zero_state())
        if txn['type'] == self.INBOUND:
            self._update_average_price(product, txn['weight'], txn['price'])
        elif txn['type'] == self.OUTBOUND:
            current_weight = Decimal(product['total_weight'])
            out_weight = Decimal(txn['weight'])
            if current_weight < out_weight:
                return False
            product['total_weight'] = f"{current_weight - out_weight:.8f}"
        return True

    def _update_average_price(self, product, weight, price):
        """Fold an inbound lot into the product's weighted-average price."""
        total_weight = Decimal(product['total_weight'])
        avg_price = Decimal(product['avg_price'])
        new_weight = Decimal(weight)
        new_price = Decimal(price)
        if total_weight > 0:
            new_total = total_weight + new_weight
            new_avg = (total_weight * avg_price + new_weight * new_price) / new_total
        else:
            # No stock on hand: the new lot alone defines the price.
            new_total = new_weight
            new_avg = new_price
        product['total_weight'] = f"{new_total:.8f}"
        product['avg_price'] = f"{new_avg:.8f}"

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------
    def calculate_outbound_price(self, product_name, outbound_date):
        """Return the product's average price as of ``outbound_date``.

        Starts from the newest usable snapshot dated before ``outbound_date``
        and applies the product's later records up to and including that
        date, using the same 8-decimal arithmetic as the stored state.

        Returns:
            The price as an 8-decimal string; ``'0.00000000'`` for an unknown
            product or when the file cannot be read.
        """
        try:
            transactions = self._load()['transactions']
            base_snapshot = self._find_base_snapshot(outbound_date, transactions)
            base_key = self._snapshot_key(base_snapshot)
            state = self._zero_state()
            if base_key is not None:
                state.update(base_snapshot.get('products', {}).get(product_name, {}))
            history = [
                txn for txn in transactions
                if not txn.get('is_snapshot')
                and txn.get('product') == product_name
                and txn.get('time', '') <= outbound_date
                and (base_key is None or self._record_key(txn) > base_key)
            ]
            history.sort(key=self._record_key)
            for record in history:
                if record['type'] == self.INBOUND:
                    self._update_average_price(state, record['weight'], record['price'])
                elif record['type'] == self.OUTBOUND:
                    remaining = Decimal(state['total_weight']) - Decimal(record['weight'])
                    if remaining < 0:
                        # Should not happen; reset instead of going negative.
                        state = self._zero_state()
                    else:
                        state['total_weight'] = f"{remaining:.8f}"
            return f"{Decimal(state['avg_price']):.8f}"
        except Exception:
            # Best effort: callers expect a price string, never an exception.
            return self._ZERO

    def get_product_list(self):
        """Return the ``products`` mapping, or ``{}`` if the file cannot be read."""
        try:
            return self._load().get('products', {})
        except (OSError, ValueError, AttributeError):
            return {}

    def initialize_data_file(self):
        """Create the data file with no products and one empty initial snapshot.

        Raises:
            Exception: the file could not be written.
        """
        initial_snapshot = {
            "is_snapshot": True,
            "timestamp": "1900-01-01",
            "daily_sequence": 0,
            "products": {}
        }
        initial_data = {
            "$schema": "./inventory_schema.json",
            "products": {},
            "transactions": [initial_snapshot]
        }
        try:
            self._save(initial_data)
        except Exception as e:
            raise Exception(f"初始化数据文件失败: {str(e)}") from e
if __name__ == '__main__':
    # Manual smoke test: record one inbound lot in the default data file
    # and print the result dictionary returned by the manager.
    demo_manager = InventoryManager('inventory_data.json')
    sample_inbound = dict(
        product='TEST001',
        type='入库',
        weight='100.00000000',
        price='10.00000000',
        note='测试入库',
        time='2025-01-15',
    )
    print(demo_manager.process_transaction(sample_inbound))